56 Commits

Author SHA1 Message Date
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Peter Hawkins
76ef28b2d3 Remove jax.util.partial. 2021-09-28 11:04:02 -04:00
Peter Hawkins
58c7ee46bc Remove jax.util.partial. 2021-09-20 20:32:49 -04:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Nathan Howell
2f3e1ad1ab Export split_dict from jax.util 2021-01-29 11:13:36 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src.util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
28f479c571 alias jax.partial to functools.partial 2021-01-08 11:52:29 -08:00
Peter Hawkins
9dd1704aef Add types to axes in lax_numpy. 2021-01-06 16:28:29 -05:00
Adam Paszke
8fcacd645c Support mapping a single logical axis to multiple mesh axes in xmap 2020-12-11 14:35:31 +00:00
Adam Paszke
ca8028950e Fix pmap compilation cache regressions from #4904.
AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
2020-12-02 14:40:45 +00:00
Adam Paszke
5879967c25 Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
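
For example, with this change in place one might write the following (a usage sketch of the semantics described above, assuming a machine with multiple devices):

```py
import jax
import jax.numpy as jnp

n = jax.device_count()
x = jnp.arange(n * 3.0).reshape((n, 3))

# out_axes=1 stacks the per-device results along axis 1,
# so the output has shape (3, n) rather than the default (n, 3).
y = jax.pmap(lambda v: v * 2, out_axes=1)(x)
print(y.shape)  # (3, n)
```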

One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.

* Implementation details *

This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.

** Thunking **

The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:

*** Transformations ***

Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  # Populate the auxiliary store immediately after the wrapped function
  # runs, so the statistic is ready by the time bind queries the thunk.
  yield outputs, compute_statistic(outputs)

wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)

def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()  # call the old thunk once
  return compute_new_out_axes(old_out_axes, output_statistic())

primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).

The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.

The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.

*** Compilation cache ***

Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.

Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
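
A minimal sketch of the kind of hashable wrapper this calls for (hypothetical names; the real helpers differ):

```py
class HashableThunk:
  """Wraps a thunk so it can participate in a cache key.

  Assumes the thunk's result depends only on `key` (e.g. the identity of
  the wrapped user function), per the determinism assumption above.
  """
  def __init__(self, fn, key):
    self.fn = fn
    self.key = key

  def __call__(self):
    return self.fn()

  def __eq__(self, other):
    return isinstance(other, HashableThunk) and self.key == other.key

  def __hash__(self):
    return hash(self.key)
```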

* Why final style? *

Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-24 17:11:38 +00:00
Adam Paszke
94fcd7f2ab Use taggedtuple instead of namedtuple when defining ShardingSpecs
Because namedtuples don't take the class into account when comparing for
equality!
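
Concretely, the pitfall:

```py
from collections import namedtuple

Point = namedtuple('Point', ['x', 'y'])
Size = namedtuple('Size', ['x', 'y'])
# namedtuple inherits __eq__ from tuple, so equality ignores the class:
assert Point(1, 2) == Size(1, 2)  # passes!
```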
2020-11-24 14:48:11 +00:00
Adam Paszke
a5bc7353de Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
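
A usage sketch of the new behavior (assuming a machine with multiple devices):

```py
import jax
import jax.numpy as jnp

n = jax.device_count()
x = jnp.ones((3, n))  # the mapped axis is axis 1, not the leading axis

# in_axes=1 splits x across devices along its second axis; each device
# sees a vector of shape (3,). Outputs are stacked along axis 0 as before.
y = jax.pmap(lambda v: v * 2, in_axes=1)(x)
print(y.shape)  # (n, 3)
```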
2020-11-10 18:35:28 +00:00
jax authors
bdd7915661 Internal change
PiperOrigin-RevId: 341644256
2020-11-10 10:12:27 -08:00
Adam Paszke
6914058cbe Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 13:35:23 +00:00
Adam Paszke
8c1fbdc901 Make ShardingSpec more flexible
In preparation of adding support for `in_axes` and `out_axes` to `pmap`.

The only difference in expressivity of the new approach is that the
sharded dimensions can be permuted before ordering/replicating the
indices to match the device assignment. This is necessary if we want to
support `in_axes`, because it may cause some sharded dimensions that are
supposed to get mapped to the "replication" XLA mesh axis to follow the
dimensions mapped to the "partitioning" XLA mesh axis. XLA fixes the
mesh order such that the replicated dimension is always the leading one,
which forces us to decouple the order of data dimensions from the mesh
dimensions.

This patch additionally folds the `is_axis_materialized` into the
sharding specification, by wrapping the integers in small ADT-like
wrappers that distinguish the different ways of partitioning dimensions.
The order of replication is also more explicit in the `mesh_mapping`,
as opposed to being represented as a list of replication factors to be
inserted into the sharding details to obtain a mesh mapping.
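
A hypothetical sketch of such ADT-like wrappers (illustrative only; the actual class names and fields differ):

```py
from dataclasses import dataclass

@dataclass(frozen=True)
class Chunked:
  chunks: int  # axis is split into `chunks` shards, each materialized

@dataclass(frozen=True)
class Unstacked:
  size: int    # axis of length `size` is mapped away (not materialized)
```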

Note that this doesn't change any existing functionality. It is purely
an internal rewrite that is supposed to lay the groundwork for the next
patches.
2020-11-09 20:00:39 +00:00
Roy Frostig
fa1d7ab5fa move wraps from api_util to util to avoid cyclic dependencies 2020-10-26 12:31:19 -07:00
Roy Frostig
5d50e19364 add path exclusion opt-in to filtered stack traces and use it throughout the codebase 2020-10-26 12:31:19 -07:00
jax authors
04fa89a12c Merge pull request #4299 from zhangqiaorjc:segsum
PiperOrigin-RevId: 332964956
2020-09-21 16:48:00 -07:00
Qiao Zhang
35d231990c Add ceil_of_ratio util and bucket_size TODO. 2020-09-21 14:45:37 -07:00
Jake VanderPlas
6b9dfb1396 fix incorrect indentation 2020-09-18 13:06:58 -07:00
Peter Hawkins
e06a6ab6bf Add support for negative axes to vmap. (#4111)
* Add support for negative axes to vmap.

* Add workaround for out-of-range vmap axes.
2020-08-24 20:21:19 -04:00
Jake Vanderplas
a7c2cdea64 Cleanup: convert uses of import numpy as onp in library code (#3754) 2020-07-14 13:05:31 -07:00
Jake Vanderplas
b813ae3aff Cleanup: record names in get_module_functions (#3697) 2020-07-08 14:44:49 -07:00
Roy Frostig
fc4ab77bc6 merge constvars when forming cond branch jaxprs 2020-05-13 21:14:41 -07:00
George Necula
c52f32b59d Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
Peter Hawkins
e60d5dd54c Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
James Bradbury
a15aa9bd4d include call stack + transforms in XLA metadata (#2073) 2020-01-26 23:27:56 -08:00
Peter Hawkins
681ba37f7e Drop fastcache dependency, which isn't necessary on Python 3. (#1995)
Drop protobuf and six dependencies from travis configuration.
2020-01-14 10:08:23 -05:00
Matthew Johnson
ad9b6d4d94 implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:

  from jax import jit
  import jax.numpy as np  # JAX's numpy, imported as np per the old convention

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See https://github.com/google/jax/pull/1668 for more.
2020-01-07 20:48:26 -08:00
Neeraj Pradhan
ed309b6f6a Fix for compatibility with python>=3.7 and numpy>=1.18 2020-01-01 08:17:55 -08:00
Jamie Townsend
f66aa275a6 Rm duplicates from end_nodes in toposort 2019-10-01 17:56:44 +01:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
Peter Hawkins
a94600c407 Optimizations to pmap computation launch time. 2019-08-05 14:08:46 -04:00
Peter Hawkins
08013954a4 Use fastcache for LRU caches in JAX.
fastcache is both a faster cache implementation and is also thread-safe.
2019-07-22 17:24:10 -04:00
Matthew Johnson
8bc4e379f5 make DeviceArray.__hash__ raise an error
Fixes #883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.

Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
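
A minimal sketch of the idea (hypothetical, mirroring numpy's unhashable ndarray):

```py
class DeviceArray:
  # Arrays define elementwise __eq__, so hashing would be misleading;
  # raising makes any accidental use as a dict key or set element loud.
  def __hash__(self):
    raise TypeError("JAX DeviceArray objects are unhashable.")
```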
2019-06-19 10:12:13 -07:00
Matthew Johnson
5aea10c7b3 make static_argnums cache on value when possible
fixes #691
2019-05-09 20:00:24 -07:00
Matthew Johnson
25479c3518 fix pmap performance bug, dont always copy to host 2019-03-26 08:35:34 -07:00
Peter Hawkins
1800d6554d Add comments to linear_util.py. 2019-03-12 15:07:52 -04:00
Peter Hawkins
2b383bdbd9 Increase cache size to 4096. 2019-01-15 10:32:58 -05:00
Peter Hawkins
3266bb3122 Change linear_util.memoize to use an LRU cache.
Add util.OrderedDict that retrofits a move_to_end method onto Python 2 OrderedDicts.
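
A minimal Python 3 sketch of the pattern (not the actual JAX code; assumes positional args are hashable):

```py
from collections import OrderedDict

def memoize(fn, max_size=4096):
  cache = OrderedDict()
  def memoized_fun(*args):
    if args in cache:
      cache.move_to_end(args)    # mark as most recently used
      return cache[args]
    result = fn(*args)
    cache[args] = result
    if len(cache) > max_size:
      cache.popitem(last=False)  # evict the least recently used entry
    return result
  return memoized_fun
```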
2019-01-14 21:48:28 -05:00
Peter Hawkins
05b1049e49 Change util.memoize to be an LRU cache with a default size of 64 entries.
The goal is to limit peak memory when running a large number of computations, e.g., the test suite.
2019-01-14 20:11:08 -05:00
Peter Hawkins
8127392a18 Use get() rather than a try-catch block in memoized function lookup.
Currently backtraces often look like this:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/p/jax/jax/util.py in memoized_fun(*args, **kwargs)
    133     try:
--> 134       return cache[key]
    135     except KeyError:

KeyError: ((lu, ShapedArray(int32[2,2])), ())

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
~/p/jax/jax/util.py in memoized_fun(*args, **kwargs)
    133     try:
--> 134       return cache[key]
    135     except KeyError:

KeyError: ((lu, xla_client.Shape(_dtype=dtype('int32'), _dimensions=(2, 2), _is_tuple=False, _minor_to_major=None)), ())

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-26-d6c00d50e3c9> in <module>
```

The "during handling of the above exception..." message is mostly a distraction for the user that occurs because we perform the memoized function evaluation inside a `catch` block. By performing the function evaluation outside the catch block, we can get better backtraces without the distraction of the KeyError exception.

2018-12-21 13:32:56 -05:00
Alex Wiltschko
763f860135 Automated detection of unimplemented functions 2018-12-11 11:52:31 -05:00
Dougal Maclaurin
30124b6da1 Added jit transformations to generated functions. Fixed bug in comparing numpy arrays for equality. 2018-12-08 00:03:34 -05:00