90 Commits

Author SHA1 Message Date
jax authors
fc7775e1d1 Merge pull request #7968 from hawkinsp:partial
PiperOrigin-RevId: 398025545
2021-09-21 10:21:13 -07:00
Peter Hawkins
1163e218e8 Attempt to land https://github.com/google/jax/pull/6400 again.
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as under `jax.jit`, will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.
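For example, a minimal sketch of the workaround (the `flatten` helper here is purely illustrative):

```
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten(x):
    # np.prod(x.shape) is a concrete NumPy integer even under tracing, so it
    # can be used as a static shape; jnp.array(x.shape).prod() would now be a
    # traced value and could not.
    return jnp.reshape(x, (np.prod(x.shape),))
```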

PiperOrigin-RevId: 398008511
2021-09-21 09:06:40 -07:00
Peter Hawkins
58c7ee46bc Remove jax.util.partial. 2021-09-20 20:32:49 -04:00
Peter Hawkins
f35ab3693d Remove jax.partial from the JAX API.
Use functools.partial instead.
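For example (a hedged illustration, not taken from the change itself):

```
from functools import partial
import jax

# functools.partial plays the same role jax.partial used to, e.g. when
# decorating a jitted function that has static arguments.
@partial(jax.jit, static_argnums=(1,))
def power(x, n):
    return x ** n
```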
2021-09-20 09:19:53 -04:00
jax authors
f47926a23d Merge pull request #7940 from hawkinsp:api
PiperOrigin-RevId: 397319298
2021-09-17 07:58:17 -07:00
Jake VanderPlas
9a2697437e Update changelog for several recent PRs 2021-09-16 14:10:08 -07:00
Peter Hawkins
6a1b626564 Remove jax.api.
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
2021-09-16 16:29:06 -04:00
Jake VanderPlas
abeeb48ba1 jnp.array: raise TypeError on boolean scalar indices 2021-09-15 12:50:44 -07:00
Jake VanderPlas
404e22ec67 Add Changelog for jax v0.2.21 development 2021-09-15 12:10:30 -07:00
Peter Hawkins
b56c2ccadd Remove export of jax.lax.partial. 2021-09-14 16:17:50 -04:00
yashkatariya
765746b60e update version and changelog for pypi 2021-09-02 15:38:47 -07:00
yashkatariya
14a02c6880 Remove new features 2021-09-01 11:26:41 -07:00
yashkatariya
84edde2f9b Add new features section 2021-09-01 10:56:54 -07:00
yashkatariya
be824a792e Update files after new jaxlib release 0.1.71 2021-09-01 10:43:20 -07:00
Jake VanderPlas
fb30fa852d update CHANGELOG for #7662 & #7732 2021-08-27 16:43:58 -07:00
Matthew Johnson
6f7be1fad9 update version and changelog for pypi 2021-08-12 21:17:53 -07:00
Peter Hawkins
beddf598bd Add @jit decorators to jax.numpy operators.
By wrapping common operators in `jit`, we get a number of benefits:
* `jit` has a faster, more optimized dispatch path compared to the primitive dispatch path in JAX. It's faster to dispatch a `jit` computation than a single primitive.
* `jit` allows us to cache and reuse logic such as broadcasting and type promotion.

One downside is that we now report an error when large Python integer scalars (e.g. `2**32 - 1`) are passed as arguments to JAX array operators. The workaround is to use explicitly typed constants instead of Python scalars.
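A hedged example of the workaround (the values and dtypes are chosen only for illustration):

```
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.uint32)
# A bare Python scalar 2**32 - 1 would now be rejected at operator dispatch;
# wrapping it in an explicit dtype avoids the error.
mask = x & np.uint32(2**32 - 1)
```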

On my laptop, this benchmark improves from 95us to 4us:

```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jax.device_put(7)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [3]: %timeit jnp.add(x, x).block_until_ready()
4.18 µs ± 159 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

PiperOrigin-RevId: 389871450
2021-08-10 06:49:28 -07:00
Yash Katariya
bf967d88d8 Upgrade versions after jaxlib release
PiperOrigin-RevId: 389753047
2021-08-09 16:37:44 -07:00
elliotwaite
7392a57b75 DOC: many small fixes 2021-08-04 16:55:13 -07:00
Peter Hawkins
6e9169d100 Drop support for NumPy 1.17. 2021-07-29 09:18:01 -04:00
George Necula
b62ceba91c [jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimension needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.

This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,

```
def average(x):
   return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.

Note that this does not fundamentally change the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but a shape never depends on the
input values. In fact, one could already have expressed `dim_as_value`
as:

```
def dim_as_value(d):
   return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to support `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll out the solution
we need to make `core.dim_as_value` a public API and teach
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
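A hedged sketch of how this could look with jax2tf shape polymorphism; the `polymorphic_shapes` string below follows the jax2tf convention of naming the batch dimension and is only illustrative:

```
import jax.numpy as jnp
from jax import core
from jax.experimental import jax2tf

def average(x):
    # Divide by the (possibly symbolic) size of the leading dimension.
    return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])

# Convert with a polymorphic batch dimension "b"; the division then uses the
# runtime value of that dimension.
average_tf = jax2tf.convert(average, polymorphic_shapes=["(b, 4)"])
```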
2021-07-27 09:02:15 +03:00
Skye Wanderman-Milne
a7916f1428 Bump jax version and CHANGELOG to 0.2.18 2021-07-21 11:56:24 -07:00
Peter Hawkins
0dfd76af97 Remove additional info return value from jax.scipy.linalg.polar(). 2021-07-20 13:13:31 -04:00
George Necula
a21683605d [host_callback] Increase number of threads for callback processing.
Previously there was one thread per device for receiving the outfeed from
devices, but there was a single global thread that was calling into the Python
callbacks. This meant that if one of the callbacks was slow, it was blocking
processing of all other callbacks.

One situation where this created difficulties was when one wanted to break a host_callback into two operations: a quick one to enqueue work on a threadpool,
and a subsequent slow one to wait for and retrieve the result. The slow callback would block all other callbacks, including possibly some quick ones, thus missing the opportunity to start the slow work.

With this change there is a separate queue of outfeeds for each device and a
separate thread per device to call into Python. This allows for concurrency
between callbacks from different devices, although the callbacks from one
device are still sequential. If the programmer wants more concurrency, they can use a threadpool. Having more concurrency by default is tricky, because it could mean that the Python callbacks for one device are seen out of order.
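A minimal sketch of the two-callback pattern described above, assuming `jax.experimental.host_callback.call`; the executor and ticket bookkeeping are a hypothetical illustration, not part of the JAX API:

```
import itertools
import concurrent.futures
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

_executor = concurrent.futures.ThreadPoolExecutor()
_pending = {}                 # ticket -> Future (hypothetical bookkeeping)
_tickets = itertools.count()

def _enqueue(x):
    # Quick callback: start the slow host-side work and return a ticket.
    ticket = next(_tickets)
    _pending[ticket] = _executor.submit(lambda v: np.asarray(v) * 2, x)
    return np.int32(ticket)

def _retrieve(ticket):
    # Slow callback: wait for and retrieve the result for this ticket.
    return _pending.pop(int(ticket)).result()

def f(x):
    ticket = hcb.call(_enqueue, x,
                      result_shape=jax.ShapeDtypeStruct((), jnp.int32))
    return hcb.call(_retrieve, ticket,
                    result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))
```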

PiperOrigin-RevId: 385493070
2021-07-19 00:18:06 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
Qiao Zhang
82e74959fe Update changelog for jaxlib-0.1.69. 2021-07-12 12:06:41 -07:00
George Necula
0beef34d25 [jax2tf] Fix conversion for argmin/argmax; add conversion for reduce
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax.
Those ops behave differently than JAX when the inputs contain NaN and Inf. Added
a few test cases in primitive_harness to expose the failures.

In order to implement an accurate conversion of argmin/argmax, we need to use the
XLA Reduce op.

Also tightened the shape checks for lax.argmin and lax.argmax, to ensure they are
not used with an empty reduced dimension. E.g., if axis=-1, previously we got
an internal error:
```
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
```
PiperOrigin-RevId: 384182794
2021-07-12 01:11:42 -07:00
Peter Hawkins
b393d9a8c1 Update jax version and changelog for 0.1.27.
Disable tfrt CPU backend on jaxlib 0.1.68 to work around https://github.com/google/jax/issues/7229.
2021-07-09 15:21:52 -04:00
James Bradbury
8e86952ee4 AWN-enabled reduction over named axes in reverse-mode AD
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.

In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, would also be mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.

If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.

Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.
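A minimal sketch, assuming the option is exposed as a `reduce_axes` argument on `jax.grad` (the argument name is an assumption here) and using `xmap` to introduce the named axis:

```
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def loss(w, x):
    return jnp.sum((x * w) ** 2)

# Per-example gradient: the cotangent of the unmapped w comes back mapped
# over 'batch' (the existing elementwise behavior).
per_example_grad = xmap(jax.grad(loss),
                        in_axes=([...], ['batch', ...]),
                        out_axes=['batch', ...])

# Total gradient: 'batch' is reduced inside the backward pass instead.
total_grad = xmap(jax.grad(loss, reduce_axes=('batch',)),
                  in_axes=([...], ['batch', ...]),
                  out_axes=[...])
```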

Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
  - reductions aren't fused into any first-order primitives (e.g. a `pdot`
    should have a named contracting axis added rather than being followed by a
    `psum`; this can be implemented by putting these primitives into
    `reducing_transposes`)
  - reductions are performed eagerly, even over axes that are mapped to
    hardware resources (the optimal thing to do would be to reduce eagerly
    over any vectorized axis component while delaying the reduction over any
    hardware-mapped component until the end of the overall backward pass; this
    would require a way to represent these partially-reduced values)

PiperOrigin-RevId: 383685336
2021-07-08 12:06:29 -07:00
tlu7
d97b393694 Adds spherical harmonics.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-07-02 10:42:29 -07:00
Qiao Zhang
61ab59c40a Update changelog for jax and jaxlib releases. 2021-06-28 13:52:19 -07:00
Skye Wanderman-Milne
3da8a4cd86 Update jax version to 0.2.16 2021-06-23 14:45:45 -07:00
Skye Wanderman-Milne
444ee5e840 Update jax version to 0.2.15 2021-06-23 11:55:40 -07:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid changing external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
George Necula
dd8ab85121 [jax2tf] Support inequality and min/max for booleans.
For inequalities we add casts to int8. For min/max we rewrite
to logical operations and/or.
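A hedged illustration of what this enables (shapes and values are only for the example):

```
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

def any_true(x):
    return jnp.max(x)   # for booleans, max is logical "or"

x = np.array([False, True, False])
print(jax2tf.convert(any_true)(x))   # True
```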
2021-06-12 21:08:37 +03:00
Peter Hawkins
b130257ee1 Drop support for NumPy 1.16. 2021-06-11 09:03:09 -04:00
George Necula
1994f6df4a [jax2tf] Fix the round-trip call_tf(convert)
Also cleaned the handling of global state in jax2tf.
2021-06-11 11:57:27 +03:00
Skye Wanderman-Milne
063401f3ef Update jax version to 0.2.14 2021-06-10 13:15:53 -07:00
George Necula
59ae45a83c [jax2tf] Add support for generating HLO OpMetadata in the TF graph
The goal is to ensure that the HLO that
jax2tf->TF/XLA generates has the same metadata as what JAX generates.
This includes `op_type`, `op_name`, and source information, which are
used for debugging and profiling.

In order to ensure that this metadata is carried from the JAX tracing
time to TF/XLA, we save the metadata in custom TF op attributes. These
attributes are automatically preserved through SavedModel. This relies
on a separate change in TF/XLA to look for these custom attributes
and override its default.

For the source information, we use pretty much the same code that
xla.py uses. HLO OpMetadata has room for only one source location.
JAX (xla.py) picks the top-most user frame, which is obtained by
filtering out the stack frames in the JAX source tree. When used
with jax2tf we also need to filter out stack frames in the
TensorFlow source tree.

The hardest part is to generate the `op_name`, which is a hierarchical
name with components separated by '/', e.g., `jax2tf(top_func)/while/cond/le`.
We carry the current `name_stack` in thread-local state. Unfortunately, there
is no easy way to share the exact code that achieves this in xla.py. At the
same time it is not crucial that we have exactly identical name stacks as in
JAX.

I attempted to also carry this state in the JAX `MainTrace`, but could not
fully control the name stack. E.g., when calling a jitted function we
have to reuse the current `MainTrace` although we want to push an element
on the name stack.

For now this option is not yet enabled until we make the necessary
changes in TensorFlow.
2021-06-09 08:08:42 +02:00
George Necula
d243258b86 [jax2tf] Implement inequalities and friends for complex numbers.
This requires re-using JAX's lowering rule for comparisons of
complex numbers to use lexicographic comparison.
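A minimal sketch of what lexicographic comparison of complex numbers means (written with jnp purely for illustration): compare the real parts first and break ties on the imaginary parts.

```
import jax.numpy as jnp

def complex_lt(x, y):
    # Lexicographic order: real parts first, imaginary parts as tie-breaker.
    return jnp.where(x.real == y.real, x.imag < y.imag, x.real < y.real)
```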
2021-06-04 17:56:44 +03:00
jax authors
ecab743e5c Merge pull request #6877 from hawkinsp:tracebacks
PiperOrigin-RevId: 377247694
2021-06-03 02:47:21 -07:00
George Necula
d03d849a19 [jax2tf] Fix the 32/64-bit behavior to follow JAX rules
JAX and TensorFlow have different behavior w.r.t. 32-64 bit
computations. This PR cleans up the handling of types in jax2tf
to ensure that we follow the same behavior in jax2tf and in JAX.

This means that f_jax(args) always does the computation with the
same precision as jax2tf.convert(f_jax)(args). This may mean that
the result of the conversion depends on the value of JAX_ENABLE_x64.

See README.md for more details.
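A hedged example of the behavior described above; the exact dtypes depend on JAX_ENABLE_X64, and the function below is only illustrative:

```
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
    return jnp.sin(x) + 1.0

x64_input = np.arange(3, dtype=np.float64)
# With JAX_ENABLE_X64 unset, JAX computes in float32 ...
print(f(x64_input).dtype)                   # float32
# ... and the converted function now matches that behavior.
print(jax2tf.convert(f)(x64_input).dtype)   # float32
```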
2021-06-03 10:12:58 +03:00
Peter Hawkins
2882286b50 Add a --jax_traceback_filtering flag to control the traceback filtering mode.
Add a new traceback filtering mode that uses __tracebackhide__, and use it in IPython.
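A minimal example of selecting a mode, assuming the flag is also exposed through jax.config and that "off" disables filtering entirely (the mode name here is an assumption):

```
import jax

# Show unfiltered tracebacks, e.g. when debugging JAX internals.
jax.config.update("jax_traceback_filtering", "off")
```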
2021-06-02 16:25:37 -04:00
jax authors
8e6101c6a1 Merge pull request #6866 from gnecula:tf_pjit
PiperOrigin-RevId: 376989780
2021-06-01 22:50:12 -07:00
George Necula
2ad9c0c34c [jax2tf] Fix the scoping of the enable_xla conversion parameter
Previously, the global enable_xla flag was set upon entry to
`jax2tf.convert`. It should instead be set only for the duration
of the just-in-time conversion, which may happen later when
the converted function is invoked.
2021-05-21 11:22:21 +03:00
Peter Hawkins
f83e309fe7 Update changelog. 2021-05-12 09:46:17 -04:00
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.
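A hedged sketch of the preferred_element_type advantage mentioned above; the shapes and dtypes are illustrative only:

```
import numpy as np
from jax import lax
from jax.experimental import jax2tf

def dot_i32(a, b):
    # Accumulate int8 inputs into int32 via preferred_element_type.
    return lax.dot_general(a, b, (((1,), (0,)), ((), ())),
                           preferred_element_type=np.int32)

a = np.ones((2, 3), dtype=np.int8)
b = np.ones((3, 4), dtype=np.int8)
print(jax2tf.convert(dot_i32, enable_xla=True)(a, b).dtype)   # int32
```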

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
Qiao Zhang
528d5bbb11 Update README etc for jaxlib 0.1.66 release. 2021-05-11 16:49:32 -07:00
jax authors
c31943cfe5 Merge pull request #6622 from hawkinsp:eightr
PiperOrigin-RevId: 372035283
2021-05-04 18:17:56 -07:00