38 Commits

Author SHA1 Message Date
Jake VanderPlas
b49c75c0d7 [x64] make jax.experimental.loops consistent with default dtype 2021-12-08 12:08:49 -08:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
elliotwaite
7392a57b75 DOC: many small fixes 2021-08-04 16:55:13 -07:00
Lena Martens
19ee7b22e1 Expose UnexpectedTracerError and add docs. 2021-07-27 23:23:28 +01:00
Jake VanderPlas
c45acd70a8 Cleanup: use pep 448 unpacking to simplify some code 2021-07-12 16:30:53 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Jake VanderPlas
a0b12bba25 DOC: fix minor formatting issues 2021-01-20 14:38:19 -08:00
George Necula
555a215cfb [loops] Extend loops with support for pytrees
Also improve error checking and error messages.
2021-01-14 21:17:14 +02:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
c63097bc90 Add weak_type argument to convert_element_type_p 2020-12-10 11:10:21 -08:00
Peter Hawkins
424594feb2 Short-circuit references to jax.core via jax.abstract_arrays. 2020-11-19 14:15:28 -05:00
Peter Hawkins
10b7d7d7c2 Move implementation of jax.lax into jax._src.lax.
Remove lax_ prefixes from jax/_src/lax filenames, since they aren't needed any longer to avoid name conflicts.
2020-10-17 16:09:21 -04:00
Matthew Johnson
6614f94890
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328)
rename and simplify TypedJaxpr -> ClosedJaxpr

This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
  in_avals / out_avals no longer need to be constructed), making them
  easier to work with;
* correspondingly rules out a class of errors (mismatches between
  invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
  but they're closed in that they are packaged with their constant
  values).

This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-09-18 10:07:13 -07:00
George Necula
634c6259df
More renaming of master to main in JAX internals (#4179) 2020-08-30 12:38:14 +03:00
Matthew Johnson
6b6789a53b
applied simple find+sed for 'master' -> 'main' (#4174)
* applied simple find+sed for 'master' -> 'main'

* Rename master->main in JAX API and internals (#4178)

* Started with #4174 
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-08-30 11:16:51 +03:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
c9d8acd2e9
put core trace state in a threading.local class (#3869)
this is a refinement of the fix in #3845, so that we no longer need
TraceState.set_state (and so that #3370 is easier to adapt)
2020-07-26 22:38:14 -07:00
Roy Frostig
8a62a9b654
block-unrolled scan primitive implementation (#3738)
* block-unrolled scan implementation, via optional `_unroll` scan parameter

* index statically in the inlined path of lax.scan

* make `unroll` a required scan parameter, and test that it unrolls
2020-07-15 14:00:50 -04:00
Jake Vanderplas
a7c2cdea64
Cleanup: convert uses of import numpy as onp in library code (#3754) 2020-07-14 13:05:31 -07:00
Roy Frostig
dc4c9f0450 change cond primitive to an indexed conditional with multiple branch functions
in the core:

* bind and check cond primitive in indexed form
* rewrite abstract evaluation rule
* rewrite translation rule
* rewrite partial evaluation rule
* rewrite batching rule
* rewrite JVP rule
* rewrite transpose rule
* update jaxpr typechecker
* update pretty printer
* update outfeed-usage check
* update reference jaxpr in cond jaxpr test
* update reference regexes in HLO test

in experimental modules:

* update host_callback rewriter
* update loops expression builder
* generalize tf_impl rule
2020-06-03 22:19:15 -07:00
joao guilherme
77e4d8b3b9
Updates onp -> np in random, loops, jet and in the tests of stax and optix (#3182) 2020-05-21 14:12:18 -07:00
Roy Frostig
efc1104cde have loops module generate same-argument jaxprs for single-operand cond 2020-05-13 21:14:41 -07:00
Matthew Johnson
3cd409ee88
add optional 'forward' argument to lax.scan (#2921)
* add optional 'forward' argument to lax.scan

* switch to reverse; revise disable-jit case

* fix jaxpr.rst

* fix loops.py

Co-authored-by: James Bradbury <jekbradbury@gmail.com>
2020-05-04 19:44:22 -07:00
George Necula
abbc70b20a Added type annotations and comments related to partial evaluation.
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:

 * instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
 * instead of PartialVal((None, pval)) we use PartialVal.known(pval)

Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
2020-04-09 13:00:33 +03:00
George Necula
d2a827a08a Ensure the global trace_state is restored on errors in loops
This is an attempted fix for https://github.com/google/jax/issues/2507
2020-04-01 10:23:14 +03:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00
George Necula
deb21ef15d Expanded the error messages due to re-using tracers saved in global state.
Previously these errors were raising Exception (as other internal errors),
but these errors may arise out of mis-use of tracers.
2020-02-15 06:35:49 +01:00
George Necula
ae3003e9d4 Simplify bound_subjaxprs.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.

We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)

This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.
2020-02-06 09:34:53 +01:00
Roy Frostig
664a4e123d
VJP of cond, via partial eval + transpose (#2091)
VJP (grad) of lax.cond, via partial eval + transpose


Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-01-30 15:03:00 -08:00
Matthew Johnson
96102dc727
simplify cond by removing consts (#2102)
Some higher-order primitives, like 'scan' and 'while', benefit from
distinguishing constants from other inputs to their closure-converted
function arguments; the reason is that for those primitives constants
act differently from the other inputs, which are loop carries or
scanned-over values, and are handled differently by transformations. For
example, they're used differently than loop carries in lattice
fixed-point computations. As another example, in scan the constants in
the forward computation are fanned out, so when transposing scan we
generate an accumulate-add.

However, these considerations don't hold true for cond: since there's no
looping going on (and hence no lattice fixed-points), constants are
treated just like the other operands. So we don't need to carry around
the distinction. That simplifies the cond rules a bit.

Co-authored-by: Roy Frostig <frostig@google.com>
2020-01-29 13:17:39 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
George Necula
8ec6ea4742 Implemented suggestions from code review.
* added example of while_range to the module docstring.
* wrap the very long lines
2019-11-18 11:39:58 +01:00
George Necula
d549d44e43 Improved documentation
Also fix for the Python 2 iterators.
2019-11-16 18:36:08 +01:00
George Necula
64e186c337 Fix tests for Python 2 and for X64 2019-11-16 18:05:45 +01:00
George Necula
d24c374d59 An implementation of an experimental syntactic sugar for 'for' loops.
See description in jax/experimental/loops.py.
2019-11-16 17:23:40 +01:00