343 Commits

Author SHA1 Message Date
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Peter Hawkins
6a45a9236d Remove the _num_buffers attribute from core.AbstractValue.
The number of buffers used to represent an abstract value is a property specific to a particular representation of that abstract value. Currently the only representation is an XLA representation, but that may change in the future. Instead, callers who want to know how XLA would represent an aval should ask the XLA module instead. In this case, we call len(xla.aval_to_xla_shapes(...)) instead.
2021-10-13 14:35:07 -04:00
Roy Frostig
9a182e66c8 order-independent hash in core.NamedShape 2021-10-12 15:53:44 -07:00
Matthew Johnson
482e41d796 remove ShapedArray.__len__
It was confusing to overload, since we sometimes think of avals like
shapes paired with dtypes, and in that case len(aval) should perhaps be
like len(aval.shape). The only place where this behavior was relied on
was sparse/ops.py.
2021-10-07 22:04:16 -07:00
Peter Hawkins
42e0d4e5f5 Remove jax._src.util.partialmethod.
Use functools.partialmethod instead, which has existed since Python 3.4. The JAX partialmethod doesn't work correctly in Python 3.10.

Issue #8097
2021-10-05 12:12:41 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
d4023508a4 Uniquify variable names globally within a jaxpr.
It is confusing when the same name is shadowed within an inner lambda expression. Use globally unique variable names in each pretty-printed jaxpr.
2021-10-01 12:49:47 -04:00
Peter Hawkins
ef560fb177 Print long variable lists more compactly. 2021-09-28 10:01:51 -04:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Adam Paszke
c845d15b3a Cache used_axis_names calls for jaxprs
All control-flow primitives are `AxisPrimitive`s now, which means that we're doing
lots of those used-names traversals during dispatch and they can be expensive!
This adds caching to try to lower the cost of that change.

PiperOrigin-RevId: 395921887
2021-09-10 07:10:07 -07:00
Sharad Vikram
cc3e197991 Combine initial_style_batchers with collective_rules 2021-09-09 11:23:51 -07:00
Adam Paszke
1158530faa Remove axis name from named_shape when unmapping avals
Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.

PiperOrigin-RevId: 395423006
2021-09-08 01:42:15 -07:00
Adam Paszke
0636f490f3 Ensure that named axes consistently refer to global axis sizes in xmap
Fixes #6959.

PiperOrigin-RevId: 395210686
2021-09-07 03:26:21 -07:00
jax authors
cc1cc98d82 Merge pull request #7783 from shoyer:set-item-errors
PiperOrigin-RevId: 394442094
2021-09-02 06:02:56 -07:00
Stephan Hoyer
d204325c1f Don't refer to deprecated jax.ops.index_update in error messages
I've also updated the docs for ``jax.ops`` to note that ``at[].set()``
is guaranteed to be performed in-place under JIT. Someone who knows XLA
well should double check that fact!
2021-09-01 20:43:13 -07:00
Matthew Johnson
8ae1245c21 add assertions 2021-08-30 11:10:10 -07:00
Matthew Johnson
83f95a5dae custom_jvp/vjp tweaks and fixes 2021-08-17 17:51:35 -07:00
Markus Kunesch
6708cd3158 Add dtype to string representation of ConcreteArray.
The string representation of ConcreteArray did not include the data type of the
wrapped value. This makes it harder to spot the reason for errors arising from
inconsistent values (issue #5364). This commit adds the data type to the string
representation of ConcreteArray.
2021-08-13 15:01:26 +00:00
jax authors
17a606a95d Merge pull request #7556 from LenaMartens:main
PiperOrigin-RevId: 389876971
2021-08-10 07:20:35 -07:00
lenamartens
ddaef095bb Reword UnexpectedTracerError and add visual dividers. 2021-08-10 14:31:40 +01:00
Peter Hawkins
a3a2ed6206 Strip debug_info and jaxpr_stack from MainTrace instances used as C++ JIT cache keys.
Without this fix, we created reference count cycles via the C++ JIT cache.
2021-08-09 16:05:02 -04:00
Jake VanderPlas
63a788b4de Cleanup: switch to new version of super() 2021-08-05 13:11:07 -07:00
George Necula
0c1a37ce33 [jax2tf] Add shape polymorphism support for jnp.eye 2021-08-03 09:19:49 +03:00
botev
69fcc0c20d Abstracts into a separate function to evaluation of a single jaxpr equation. 2021-08-01 18:39:51 +01:00
Matthew Johnson
c31688d2d1 fix cond-of-pmap bug 2021-07-29 10:34:43 -07:00
Lena Martens
2190734637 Add tracers to LeakChecker error, and filter out false positives this way.
If we can't find any hanging tracers in the gc.get_referrers chain, is it
really a leak? Probably not!
2021-07-29 15:45:24 +01:00
Lena Martens
19ee7b22e1 Expose UnexpectedTracerError and add docs. 2021-07-27 23:23:28 +01:00
George Necula
b62ceba91c [jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimenion needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.

This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,

```
def average(x):
   return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.

Note that this does not change fundamentally the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed the `dim_as_value`
already:

```
def dim_as_value(d):
   jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to suppot `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll-up the solution
we need to make `core.dim_as_value` a public API and teach the
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
2021-07-27 09:02:15 +03:00
Roy Frostig
258ae44303 refine constant-hoisting heuristic for closure_convert
Instead of hoisting all float-type arrays during closure conversion,
only hoist JVPTracers (or tracers carrying such tracers
indirectly). Doing so better approximates the subset of
closure-captured values that participate in AD.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-07-26 18:50:45 -07:00
Lena Martens
24c9a933d6 Add shape and dtype of leaked tracer to UnexpectedTracerError. 2021-07-21 17:50:44 +01:00
Adam Paszke
d25f4b34b8 Add an option to strictly enforce sharding implies by named axes
At the moment, xmap SPMD lowering only enforces sharding constraints for
computation inputs and outputs, while leaving sharding propagation in the
body entirely up to the XLA SPMD partitioner. This patch adds a new flag
`experimental_xmap_enforce_inferred_sharding` that inserts additional
sharding constraint between every JAX primitive in the xmapped function.
Assuming that the SPMD partitioner never overrides user-defined constraints,
this should restrict it sufficiently to generate a computation that is
partitioned exactly as implied by the evolution of intermediate named shapes.

PiperOrigin-RevId: 385562158
2021-07-19 08:39:27 -07:00
George Necula
7e335e0e2e [jax2tf] Fix conversion of gradients for shape polymorphic functions.
This fixes the case when the primal shape polymorphic function has
output shapes that are polynomials of the input shapes (not just
dimension variables).
2021-06-23 11:20:11 +02:00
Qiao Zhang
57234e7eaa Fix typos and indent. 2021-06-16 11:10:42 -07:00
Adam Paszke
490f9778c8 Raise a friendlier error message when using loop axes in collectives 2021-06-08 11:55:03 +00:00
George Necula
2ccda70d83 [jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials.
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multi-variate dimension polynomials. These polynomials overload addition, subtraction,
multiplication. They also partially support equality and inequality checking.

Equality and inequality are supported only when the operation result is the
same for all valuations of variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, for
the following a `core.InconclusiveDimensionOperation` is raised: `a = b`, `a
>= 2`.

Division  is supported only in the cases when either there is no remainder,
or the divisor is a constant.

This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:

```
   y = x.reshape((2, -1))
   z = ... y ...
   return z.reshape(x.shape)
```
2021-06-03 10:58:06 +03:00
Matthew Johnson
d21e8c0657 handle case where trace debug_info is None 2021-05-06 18:38:20 -07:00
jax authors
3c6a41eb9c Merge pull request #6612 from google:tracer-errors
PiperOrigin-RevId: 372211269
2021-05-05 14:45:57 -07:00
Matthew Johnson
7ec0b40173 Roll-forward of #6584, which broke internal tests.
PiperOrigin-RevId: 371839298
2021-05-03 21:41:23 -07:00
Matthew Johnson
b9d72a480f improve concreteness error from arguments
also tweak some error message wording
2021-05-03 17:37:34 -07:00
jax authors
75b00a1235 Copybara import of the project:
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:

add 'inline' option to xla_call for jaxpr inlining

--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:

add 'inline' to jit docstring

--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:

new_sublevel in jax2tf

PiperOrigin-RevId: 371542778
2021-05-01 22:18:39 -07:00
jax authors
e8f209c775 Merge pull request #6584 from google:jit-inline-2
PiperOrigin-RevId: 371541392
2021-05-01 21:53:34 -07:00
Adam Paszke
893b5a09ea Make pjit into an initial style primitive
Partly to make it more robust (e.g. we no longer need to implement
post_process_call), partly beacuse it is not really a call primitive
(it modifies the argument and return avals in a multiprocess mesh),
and partly as an experiment to see how difficult would it be to actually
make it more autodidax-like.

Overall, it seems like a mixed bag, maybe slightly positive. The thunks
are gone which is nice, but one has to be much more careful when dealing
with avals.

PiperOrigin-RevId: 371352737
2021-04-30 09:57:36 -07:00
Matthew Johnson
3c400a3e58 add 'inline' option to xla_call for jaxpr inlining 2021-04-28 19:38:15 -07:00
Adam Paszke
8df502aeb2 Use the axis names attached to a primitive when selecting the top trace
This is useful e.g. for handling psums of values that are not sharded,
but are also not statically known constants that we can fold.
2021-04-28 09:46:24 +00:00
Adam Paszke
23f847e0d3 Make the initial-style -> final-style conversion rule based
Also, add a rule for pjit to make sure that we can eval jaxprs that
contain pjits.

PiperOrigin-RevId: 369964136
2021-04-22 15:30:30 -07:00
Adam Paszke
454f5e67b1 Enable nesting xmaps in SPMD lowering mode
This uses the recently added ability to modify the `BatchTrace` to add a new
`SPMDBatchTrace`, which additionally fills in `spmd_in_axes` and `spmd_out_axes`
of xmap primitives. Those fields are necessary, because XLA does not allow us to
emit partial `OpSharding` annotations, meaning that we have to track where the
positional axes of outer xmaps are inserted at the boundaries of inner xmaps.
Otherwise, XLA could misinterpret our intention and unnecessarily force
replication along the mesh axes used by the outer xmaps.

PiperOrigin-RevId: 369831571
2021-04-22 02:47:07 -07:00
Peter Hawkins
5261b776d2 Handle context manager configuration settings for matmul precision and numpy rank promotion correctly in JIT and linear_util caches.
PiperOrigin-RevId: 369643419
2021-04-21 06:36:35 -07:00
Adam Paszke
828e210601 Add a type checking rule for xmap
Also fix the type checking code in core, which incorrectly propagated output
avals of custom type checking rules.

PiperOrigin-RevId: 369485371
2021-04-20 11:43:33 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00