275 Commits

Author SHA1 Message Date
George Necula
e37727cbce [jax2tf] Implementation of a parametric shape-polymorphism feature for jax2tf.
See the PR description.
2021-04-08 10:42:38 +03:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Peter Hawkins
368f3f056e Rollforward of:
[JAX] Add an opaque `extra_jit_context` field to the JAX C++ jit code.

This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.

PiperOrigin-RevId: 364611475
2021-03-23 12:00:43 -07:00
Peter Hawkins
7890d6cc2a Rollback of:
[JAX] Add an opaque `extra_jit_context` field to the JAX C++ jit code.

This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.

PiperOrigin-RevId: 364599983
2021-03-23 11:12:02 -07:00
Peter Hawkins
f2a6d46426 [JAX] Add an opaque extra_jit_context field to the JAX C++ jit code.
This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.

PiperOrigin-RevId: 364563982
2021-03-23 08:35:05 -07:00
Matthew Johnson
af59542d00 Re-applying the changes in #6014, after they had to be rolled-back.
PiperOrigin-RevId: 364200195
2021-03-21 13:40:20 -07:00
jax authors
4f8814a760 Copybara import of the project:
--
bf15ba5310d5f9009571928f70548bcbc7e856c3 by Matthew Johnson <mattjj@google.com>:

don't device transfer in convert_element_type

Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
PiperOrigin-RevId: 363995032
2021-03-19 16:35:37 -07:00
Matthew Johnson
bf15ba5310 don't device transfer in convert_element_type
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2021-03-19 13:42:33 -07:00
Lena Martens
d86dd24bf8 Make sublevel weak-referable, and enable the leak checker on sublevels.
Reimplement Sublevel to not inherit from `int`.
See docs on weakref: "CPython implementation detail: Other built-in types
such as tuple and int do not support weak references even when subclassed."
2021-03-18 18:49:41 +00:00
jax authors
6515b5f676 Merge pull request #5977 from apaszke:xmap-with-control-flow
PiperOrigin-RevId: 361854852
2021-03-09 11:22:18 -08:00
Adam Paszke
ec29275d7e Substitute axis names in nested jaxprs
Previously any collectives buried inside control flow would fail to
compile with xmap, because it would not traverse those with its name
substitution. This adds a "catch-all" default substitution rule which
recursively applies to all jaxpr found in the params (at the top level).
2021-03-08 18:11:07 +00:00
Adam Paszke
2c7c86a4ba Reenable multi-axis all_to_all 2021-03-08 12:45:03 +00:00
Jake VanderPlas
12c84e7a50 Add jax.errors submodule & error troubleshooting docs 2021-03-03 12:39:12 -08:00
Jake VanderPlas
56687e92e8 Improve error when tracer is used as a list index 2021-02-25 13:35:41 -08:00
jax authors
babf249705 Merge pull request #5717 from google:dynamic-shapes2
PiperOrigin-RevId: 357851603
2021-02-16 18:45:13 -08:00
Peter Hawkins
ff3b402ec0 Improve error messages for invalid JAX types returned by batched functions. 2021-02-16 20:02:11 -05:00
Matthew Johnson
786970130f add dynamic-shape-jaxpr experimental file
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2021-02-14 13:11:44 -08:00
Matthew Johnson
671fe938bd make map/unmap aval functions extensible
Also, use them in batching.py.

This change is needed for the dynamic shapes prototype in #5717, since
we add new types that can be mapped over.
2021-02-13 08:21:44 -08:00
Matthew Johnson
ca4f7f7964 add check for __jax_array__ method before error
Before raising an error on an unrecognized type, first check if the
object defines a __jax_array__ method. If it does, call it!

This provides a way for custom types to be auto-converted to
JAX-compatible types.

Implementing this method is not sufficient for a type to be duck-typed
enough for use with jax.numpy. But it may be necessary. That is, someone
trying to add a duck-typed array to be used with JAX identified a need
for __jax_array__ or similar. The user would still need to add lots of
other properties and methods, like dtype and shape attributes.

revives #4725 after it was rolled back. fixes #5356.
2021-02-05 20:30:14 -08:00
James Bradbury
10dcb26cb3 [avals with names] Add named_shape to ShapedArray and update typecompat
The second change in the avals-with-names stack:
- https://github.com/google/jax/pull/5524 Revise aval constructor call sites to use a new `aval.update` method
- **Add `named_shape` to `ShapedArray` and update typecompat**
- Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules
- Make `mapped_aval`, `unmapped_aval`, and their xmap equivalents swap positional and named axes (rather than just creating and deleting positional ones)
- Enable `lax.full` to create values with named axes
- Ensure `grad` and `jacfwd`/`jacrev` consistently act elementwise over named axes (by e.g. using a seed with named axes in `grad`, and prohibiting collectives if TAP isn't too unhappy) and align `vmap(transpose)` with `transpose(vmap)` by moving the `psum` in `transpose(psum)` into `backward_pass`
- Add `axis_name` kwarg to grad to indicate operating collectively over one or more named axes

PiperOrigin-RevId: 355880632
2021-02-05 10:41:05 -08:00
Adam Paszke
f750969afc Add support for axis names in jax.nn.initializers.variance_scaling
... as well as in a few random functions that it needs (`uniform`,
`normal` and `truncated_normal`). The interface itself doesn't change to
much with the exception of the `shape` arguments of all those functions
now accepting `jax.core.NamedShape` (I didn't move it to be part of the
API just yet, but we can do that any time), which makes it possible to
generate sharded random arrays (in particular the random bits are
different on different shards). I also haven't updated the docstrings,
because I don't know if we're ready to go fully public with this
feature.
2021-02-04 12:38:12 +00:00
Adam Paszke
baf6ed11cf Generalize the access to axis names embedded in primitives
Previously, a few places in our code assumed that all collectives (i.e.
primitives that operate over named axes) keep all of their axes in the
`axis_name` attribute. This was fine for a few simple use cases, but we
are now considering allowing named axes in many more primitives which
can have semantically different attributes where axis names can appear.
2021-01-29 17:31:40 +00:00
James Bradbury
f1918f0b19 [avals with names] Revise aval constructor call sites to use a new aval.update method
PiperOrigin-RevId: 354182876
2021-01-27 15:14:02 -08:00
Matthew Johnson
9787894d94 refactor batching transform logic, fix leak checks
See PR description in #5492 for details.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2021-01-22 20:17:03 -08:00
Matthew Johnson
203af4517b revive the leak checker, as a debug mode
Co-authored-by: James Bradbury <jekbradbury@google.com>
2021-01-22 18:31:00 -08:00
Peter Hawkins
dd34d48fd1 Fix exception when tokens are used in AD. 2021-01-22 11:00:31 -05:00
Matthew Johnson
84e91d5f1d add transformed fun src info to escaped tracer err
This change adds to the error message when we hit an escaped tracer. In
particular, it adds source info for the function that was transformed.

This change currently only applies to escaped `DynamicJaxprTracer`s
(arising from `jit`, `pmap`, `scan`, and other staging functions) and
not other traces. A natural follow-up would be to attach this
information to other traces.

Co-authored-by: Lena Martens <lenamartens@google.com>
2021-01-20 15:30:37 -08:00
Matthew Johnson
47f7cd4680 avoid printing double periods in error messages 2021-01-18 20:37:12 -08:00
Matthew Johnson
886b26ffeb add source line info to more escaped tracer errors
This extra source info is still only on jaxpr staging tracers, but those
seem to be the most common culprits. I moved the `_line_info` attribute
to the base Tracer class in core.py in anticipation of populating it for
more traces than just DynamicJaxprTrace, but I'll leave that extension
to follow-up.

I adapted the main escaped tracer error messages in core.py, and also
slightly generalized and debugged source_info_util functions (thanks for
explaining the path prefix bug, @froystig !).
2021-01-18 19:00:04 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Matthew Johnson
cdc1b0546a remove AbstractValue.at_least_vspace default impl 2020-12-29 11:43:44 -08:00
Lena Martens
d1cdd7756c Fix UnexpectedTracer omnistaging error. 2020-12-19 17:46:08 +01:00
jax authors
7c294e62f4 Copybara import of the project:
--
7342318774c6f1195f0e238f1209425109ea8944 by Matthew Johnson <mattjj@google.com>:

check for __jax_array__ method for conversion

--
6742016382b0511f5ac9ec21f67d2122a9f37cb7 by Matthew Johnson <mattjj@google.com>:

fix typo

--
5eb36855e53d8d4e81e281d08dc9264d2671f21f by Matthew Johnson <mattjj@google.com>:

ensure some jnp funs duck-type with __jax_array__

PiperOrigin-RevId: 347763582
2020-12-15 23:13:29 -08:00
Matthew Johnson
7342318774 check for __jax_array__ method for conversion 2020-12-14 17:09:25 -08:00
Adam Paszke
ca8028950e Fix pmap compilation cache regressions from #4904.
AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
2020-12-02 14:40:45 +00:00
Matthew Johnson
8057cf919e simplify vmap collectives from two sets of rules to one
Specifically we:
1. remove the need for split_axis rules in batching.py, and instead just
rely on collective rules (namely to handle vectorizing over a single
named axis even if the collective is applied over multiple named axes)
2. simplify BatchTrace.process_primitive so that we don't pass tracers
into rules and rely on a subtle recursion

This change breaks all_to_all when used with multiple axis names, and in
particular it breaks all_to_all given the current gmap/xmap lowering
strategy of substituting multiple axis names in place of single axis
names. We believe we can replicate the previous logic with the new rule
organization, but we're leaving that for follow-up work because it's
tricky, and because we might end up changing lowering strategies not to
require axis substitution in the same way.
2020-11-25 10:15:21 -08:00
jax authors
c7057d5fb1 Merge pull request #5005 from apaszke:xmap-primitive
PiperOrigin-RevId: 344263137
2020-11-25 09:07:27 -08:00
Adam Paszke
5ee2de1675 Forbid pmap/soft_pmap/sharded_jit inside xmap 2020-11-25 13:47:05 +00:00
Adam Paszke
5879967c25 Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.

One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.

* Implementation details *

This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.

** Thunking **

The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:

*** Transformations ***

Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()
  return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).

The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.

The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.

*** Compilation cache ***

Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.

Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.

* Why final style? *

Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-24 17:11:38 +00:00
Adam Paszke
2494e0c339 Add XLA lowering for xmap
This should allow us to try out xmap not only in a simulation (i.e.
faking the devices using vmap, which we still support), but also on real
hardware.

Limitations:
- No compilation caching yet
- Nested xmaps not supported yet
- Transforms (AD, vmap, etc.) of xmaps not supported yet

Benefits:
- An xmap over multiple mesh axes already implements a more efficient
  lowering than the one used for nested pmaps.

The `resources` context-manager is now called `fake_resources`, while
real meshes can be defined in a specific context using the
`mesh(devices, axis_names)` manager. `devices` is supposed to be an
`ndarray` of JAX device objects (e.g. obtained from `jax.devices()`),
while `axis_names` should be a tuple of length matching the rank of
`devices` and specifying mesh axis names.

For concrete examples see the changes in `gmap_tests.py`.

In principle the current version of the code should also work in a
multi-host setting, but I haven't tested it just yet.
2020-11-24 11:13:49 +00:00
Peter Hawkins
84c723fc9e [JAX] Move pprint_util into jax._src.
PiperOrigin-RevId: 343279975
2020-11-19 06:42:19 -08:00
jax authors
69c920c601 Merge pull request #4796 from qiuminxu:add_jax_named_call
PiperOrigin-RevId: 342951787
2020-11-17 14:52:52 -08:00
Roy Frostig
78c6e4e5e5 fix check_jaxpr docstring 2020-11-13 18:00:33 -08:00
jax authors
83a38f4f3b Merge pull request #4854 from j-towns:tidy-stack
PiperOrigin-RevId: 342264328
2020-11-13 08:07:45 -08:00
Matthew Johnson
8b006f6a90 add correct annotations to core.TraceStack 2020-11-13 07:23:02 -08:00
Qiumin Xu
0f8ea37556 Update core.py 2020-11-12 17:36:46 -08:00
Qiumin Xu
31600aac62 Add named_call public API.
Move named_call_p to core.py from lax.py.
Also move the translation rule to jax/interpreters/xla.py where the core_call translation rule is.
2020-11-12 17:32:01 -08:00
Adam Paszke
a5bc7353de Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 18:35:28 +00:00