374 Commits

Author SHA1 Message Date
Matthew Johnson
1cf7d4ab5d Copybara import of the project:
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
jax authors
67723da38b Merge pull request #9143 from mattjj:fix-jaxpr-checking-error-messages
PiperOrigin-RevId: 420866862
2022-01-10 15:09:04 -08:00
Matthew Johnson
3548e023ec fix jaxpr type checking error messages
The pretty-printing changes a few months ago defined variable names
based on the state in JaxprPpContext instances. But that meant incorrect
variable names could be printed in jaxpr type checking error messages.

This commit correctly threads through the context so as to provide
error messages with coherent variable names.
2022-01-09 20:07:58 -08:00
Matthew Johnson
73b530aead Rolling forward again...
PiperOrigin-RevId: 420551242
2022-01-08 22:46:06 -08:00
jax authors
c335dfcc2a Re-applying #9136 after it was rolled back.
PiperOrigin-RevId: 420548165
2022-01-08 22:08:44 -08:00
Matthew Johnson
b0dabab99c Re-applying #9136 after it was rolled back.
PiperOrigin-RevId: 420545623
2022-01-08 21:39:41 -08:00
jax authors
0e201425e6 Internal change
PiperOrigin-RevId: 420436636
2022-01-07 23:27:47 -08:00
Matthew Johnson
c3b1d0dfd0 simpler jaxpr eqn params to bind params conversion
Final-style higher-order primitives, like call_p, xla_call_p (underlying
jit), xla_pmap_p (underlying pmap), and xmap_p (underlying xmap) have
slightly different bind signatures (while tracing) from their signatures
when they appear in jaxprs. In particular, their trace-time binds are
parameterized by a Python callable (or really a lu.WrappedFun)
representing the function to be applied, while in jaxpr eqns they are
parameterized by a jaxpr representing the same.

As a result, to round-trip from jaxpr to Python traceable, in
core.eval_jaxpr we have to convert from one parameter signature to the
other. (Basically we had to take the jaxpr and turn it into a Python
callable, via lu.wrap_init(partial(core.eval_jaxpr, call_jaxpr, ...)).)

However due to historical path dependence these conversion mechanisms
were all slightly distinct and kind of a mess. There was a case analysis
for call_jaxpr and map_jaxpr in core.eval_jaxpr_eqn (a helper function
created only because of this complexity), and there was a separate table
only used for the xmap rule.

In this PR we uniformized things! We basically only have a table (to
simplify core.eval_jaxpr), but instead of having it as a table we just
attached the rules to the different primitive classes (CallPrimitive,
MapPrimitive, and XmapPrimitive) to make things less error-prone (we
have a few different CallPrimitive instantiations, like call_p,
xla_call_p, named_call_p, and remat_call_p, and this way we don't have
to remember to populate the table separately for each).

This was actually a warmup simplification before we attempt to simplify
custom derivatives (to unify custom_jvp_call_p and
custom_jvp_call_jaxpr_p).

Co-authored-by: Roy Frostig <frostig@google.com>
2022-01-07 21:37:36 -08:00
jax authors
f801bb293c Merge pull request #8955 from mattjj:djax-now-2
PiperOrigin-RevId: 420123832
2022-01-06 12:52:35 -08:00
Matthew Johnson
6ce38acca8 remove axis name logic from Primitive / bind
Instead, just give AxisPrimitive its own bind function. This way the
logic is nicely separated by concerns. In addition, this factorization
will let us more easily experiment with other ways to find the top trace
(e.g. for assert_p in checkify).
2022-01-05 14:16:37 -08:00
Matthew Johnson
4db899007b add staging logic for polymorphic shapes in jaxprs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-01-05 14:11:12 -08:00
George Necula
3021d3e2e2 [hcb] Add support for remat2 to host_callback
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
2021-12-15 10:32:15 +02:00
Peter Hawkins
0c169764ed Use .__mro__ instead of .mro() when enumerating superclasses of a type.
mro() has a different signature on metaclasses, but __mro__ is a cached tuple property that appears to have the same signature everywhere. As far as I can tell, it always exists.

PiperOrigin-RevId: 416410647
2021-12-14 15:36:25 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00
Matthew Johnson
659f8b794f add skeleton checkify transformation 2021-12-01 10:44:58 -08:00
Lena Martens
cb6a3f216f Leak checker: garbage collect before collecting hanging references. 2021-11-30 14:02:51 +00:00
jax authors
f196f3780a Merge pull request #8667 from mattjj:custom-pp-eqn-rules
PiperOrigin-RevId: 412140658
2021-11-24 15:01:24 -08:00
Matthew Johnson
3d16a32986 add option for enabling custom jaxpr pprint rules 2021-11-24 14:31:58 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
839d410de0 [MLIR] Move most MLIR translation rules into lax.
PiperOrigin-RevId: 411942327
2021-11-23 18:58:28 -08:00
Matthew Johnson
8430deda3e custom pp_eqn rules, simpler xla_call print 2021-11-23 15:52:52 -08:00
Jake VanderPlas
496e400c71 [x64] Make autodiff respect weak types 2021-11-23 15:04:08 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
jax authors
f08a5a07a8 Merge pull request #8552 from mattjj:elide-more-convert-element-types
PiperOrigin-RevId: 411082070
2021-11-19 09:44:30 -08:00
Matthew Johnson
abbf78b5c3 generalize jaxpr simplification machinery
also:
* fix jit invariance bug around weak types
* elide trivial broadcasts

This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.

But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](b53a174042/jax/interpreters/partial_eval.py (L1225))
was not paying attention to weak types and hence clobbered them.

In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.

These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
2021-11-19 09:00:59 -08:00
George Necula
3715fcb930 Added workaround for bug in XLA 2021-11-18 11:01:50 +02:00
George Necula
75155f5eda [shape_poly] Refactor arange and image_resize for shape polymorphism
Bug: 8367

Small refactoring to jax.image.resize to make it compatible with
shape polymorphismin jax2tf. In the process added also support for
jnp.arange([dim_poly]). Note that the underlying lax.iota already
supported shape polymorphism.
2021-11-18 10:27:32 +02:00
Jake VanderPlas
b472ac3c46 jax_check_tracer_leaks: add warning about debuggers 2021-11-08 09:21:18 -08:00
Sharad Vikram
32319e1bc3 Fix forward for PR #8392 (made source_info for new_jaxpr_eqn argument optional again)
PiperOrigin-RevId: 406466709
2021-10-29 15:50:14 -07:00
jax authors
2ab00151ed Copybara import of the project:
--
b40245e38d7837a7777735ad60f3b5b1ac2d499d by Sharad Vikram <sharad.vikram@gmail.com>:

Use `SourceInfo` named tuple to keep track of source information

PiperOrigin-RevId: 406293469
2021-10-28 23:07:56 -07:00
Sharad Vikram
b40245e38d Use SourceInfo named tuple to keep track of source information 2021-10-28 13:31:26 -07:00
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Peter Hawkins
6a45a9236d Remove the _num_buffers attribute from core.AbstractValue.
The number of buffers used to represent an abstract value is a property specific to a particular representation of that abstract value. Currently the only representation is an XLA representation, but that may change in the future. Instead, callers who want to know how XLA would represent an aval should ask the XLA module instead. In this case, we call len(xla.aval_to_xla_shapes(...)) instead.
2021-10-13 14:35:07 -04:00
Roy Frostig
9a182e66c8 order-independent hash in core.NamedShape 2021-10-12 15:53:44 -07:00
Matthew Johnson
482e41d796 remove ShapedArray.__len__
It was confusing to overload, since we sometimes think of avals like
shapes paired with dtypes, and in that case len(aval) should perhaps be
like len(aval.shape). The only place where this behavior was relied on
was sparse/ops.py.
2021-10-07 22:04:16 -07:00
Peter Hawkins
42e0d4e5f5 Remove jax._src.util.partialmethod.
Use functools.partialmethod instead, which has existed since Python 3.4. The JAX partialmethod doesn't work correctly in Python 3.10.

Issue #8097
2021-10-05 12:12:41 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
d4023508a4 Uniquify variable names globally within a jaxpr.
It is confusing when the same name is shadowed within an inner lambda expression. Use globally unique variable names in each pretty-printed jaxpr.
2021-10-01 12:49:47 -04:00
Peter Hawkins
ef560fb177 Print long variable lists more compactly. 2021-09-28 10:01:51 -04:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Adam Paszke
c845d15b3a Cache used_axis_names calls for jaxprs
All control-flow primitives are `AxisPrimitive`s now, which means that we're doing
lots of those used-names traversals during dispatch and they can be expensive!
This adds caching to try to lower the cost of that change.

PiperOrigin-RevId: 395921887
2021-09-10 07:10:07 -07:00
Sharad Vikram
cc3e197991 Combine initial_style_batchers with collective_rules 2021-09-09 11:23:51 -07:00
Adam Paszke
1158530faa Remove axis name from named_shape when unmapping avals
Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.

PiperOrigin-RevId: 395423006
2021-09-08 01:42:15 -07:00
Adam Paszke
0636f490f3 Ensure that named axes consistently refer to global axis sizes in xmap
Fixes #6959.

PiperOrigin-RevId: 395210686
2021-09-07 03:26:21 -07:00
jax authors
cc1cc98d82 Merge pull request #7783 from shoyer:set-item-errors
PiperOrigin-RevId: 394442094
2021-09-02 06:02:56 -07:00
Stephan Hoyer
d204325c1f Don't refer to deprecated jax.ops.index_update in error messages
I've also updated the docs for ``jax.ops`` to note that ``at[].set()``
is guaranteed to be performed in-place under JIT. Someone who knows XLA
well should double check that fact!
2021-09-01 20:43:13 -07:00
Matthew Johnson
8ae1245c21 add assertions 2021-08-30 11:10:10 -07:00
Matthew Johnson
83f95a5dae custom_jvp/vjp tweaks and fixes 2021-08-17 17:51:35 -07:00
Markus Kunesch
6708cd3158 Add dtype to string representation of ConcreteArray.
The string representation of ConcreteArray did not include the data type of the
wrapped value. This makes it harder to spot the reason for errors arising from
inconsistent values (issue #5364). This commit adds the data type to the string
representation of ConcreteArray.
2021-08-13 15:01:26 +00:00