# Change log
Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
<!--
Remember to align the itemized text with the first line of an item within a list.
PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->
## jaxlib 0.1.73 (Unreleased)
## jax 0.2.22 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.21...main).
* Breaking Changes
* Static arguments to `jax.pmap` must now be hashable.
Unhashable static arguments have long been disallowed on `jax.jit`, but they
were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static
arguments using object identity.
This behavior is a footgun, since comparing arguments using
object identity leads to recompilation each time the object identity
changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap`
wants to compare static arguments by object identity, they can define
`__hash__` and `__eq__` methods on their objects that do that, or wrap their
objects in an object that has those operations with object identity
semantics. Another option is to use `functools.partial` to encapsulate the
unhashable static arguments into the function object.
* `jax.util.partial` was an accidental export that has now been removed. Use
`functools.partial` from the Python standard library instead.
* Deprecations
* The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are
deprecated and will be removed in a future JAX release. Please use
[the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a
`DeprecationWarning`.
* New features:
* An optimized C++ code-path improving the dispatch time for `pmap` is now the
default when using jaxlib 0.1.72 or newer. The feature can be disabled using
the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
* `jax.numpy.unique` now supports an optional `fill_value` argument ({jax-issue}`#8121`)
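A minimal sketch of the new `fill_value` option; it assumes the `size` argument that gives `jnp.unique` a static output shape under `jit`, and the array values are illustrative only:
```
import jax
import jax.numpy as jnp

@jax.jit
def unique4(x):
  # size fixes the output shape so the call works under jit;
  # fill_value pads the result when there are fewer unique values.
  return jnp.unique(x, size=4, fill_value=-1)

print(unique4(jnp.array([3, 1, 3, 2])))  # [ 1  2  3 -1]
```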
## jaxlib 0.1.72 (Oct 12, 2021)
* Breaking changes:
* Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 11.1+.
## jax 0.2.21 (Sept 23, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking Changes
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
`jax.*` instead.
* `jax.partial`, and `jax.lax.partial` were accidental exports that have now
been removed. Use `functools.partial` from the Python standard library
instead.
* Boolean scalar indices now raise a `TypeError`; previously this silently
returned wrong results ({jax-issue}`#7925`).
* Many more `jax.numpy` functions now require array-like inputs, and will error
if passed a list ({jax-issue}`#7747` {jax-issue}`#7802` {jax-issue}`#7907`).
See {jax-issue}`#7737` for a discussion of the rationale behind this change.
* When inside a transformation such as `jax.jit`, `jax.numpy.array` always
stages the array it produces into the traced computation. Previously
`jax.numpy.array` would sometimes produce an on-device array, even under
a `jax.jit` decorator. This change may break code that used JAX arrays to
perform shape or index computations that must be known statically; the
workaround is to perform such computations using classic NumPy arrays
instead (see the sketch at the end of this release's notes).
* `jnp.ndarray` is now a true base-class for JAX arrays. In particular, this
means that for a standard numpy array `x`, `isinstance(x, jnp.ndarray)` will
now return `False` ({jax-issue}`#7927`).
* New features:
* Added {func}`jax.numpy.insert` implementation ({jax-issue}`#7936`).
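As a small illustration of the `jax.numpy.array` staging change above, a sketch of the suggested workaround of doing static shape arithmetic with classic NumPy (function and array names are made up):
```
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten_pairs(x):
  # Use classic NumPy for shape arithmetic that must stay static;
  # jnp operations would now stage this value into the traced computation.
  n = np.prod(x.shape) // 2
  return jnp.reshape(x, (n, 2))

print(flatten_pairs(jnp.arange(8.0)).shape)  # (4, 2)
```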
## jax 0.2.20 (Sept 2, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.19...jax-v0.2.20).
* Breaking Changes
* `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`)
* `jnp.unique` and other set-like operations now require array-like inputs
({jax-issue}`#7662`)
## jaxlib 0.1.71 (Sep 1, 2021)
* Breaking changes:
* Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 10.2 and CUDA 11.1+.
## jax 0.2.19 (Aug 12, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.18...jax-v0.2.19).
* Breaking changes:
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The `jit` decorator has been added around the implementation of a number of
operators on JAX arrays. This speeds up dispatch times for common
operators such as `+`.
This change should largely be transparent to most users. However, there is
one known behavioral change, which is that large integer constants may now
produce an error when passed directly to a JAX operator
(e.g., `x + 2**40`). The workaround is to cast the constant to an
explicit type (e.g., `np.float64(2**40)`).
* New features:
* Improved the support for shape polymorphism in jax2tf for operations that
need to use a dimension size in array computation, e.g., `jnp.mean`.
({jax-issue}`#7317`)
* Bug fixes:
* Some leaked trace errors from the previous release ({jax-issue}`#7613`)
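A minimal sketch of the large-integer-constant change noted under Breaking changes above (the arrays are illustrative):
```
import numpy as np
import jax.numpy as jnp

x = jnp.arange(3)
# x + 2**40 may now raise an error, because the Python int constant does not
# fit the default 32-bit integer type; give the constant an explicit dtype:
y = x + np.float64(2**40)
```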
## jaxlib 0.1.70 (Aug 9, 2021)
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The host_callback mechanism now uses one thread per local device for
making the calls to the Python callbacks. Previously there was a single
thread for all devices. This means that the callbacks may now be called
interleaved. The callbacks corresponding to one device will still be
called in sequence.
## jax 0.2.18 (July 21 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...jax-v0.2.18).
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* The minimum jaxlib version is now 0.1.69.
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
removed.
* New features:
* Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).
* Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
not used with an invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`#7196`)
## jaxlib 0.1.69 (July 9 2021)
* Fix bugs in the TFRT CPU backend that resulted in incorrect results.
## jax 0.2.17 (July 9 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
* Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
to work around #7229, which caused wrong outputs on CPU due to a concurrency
problem.
* New features:
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
* Reverse-mode autodiff functions ({func}`jax.grad`,
{func}`jax.value_and_grad`, {func}`jax.vjp`, and
{func}`jax.linear_transpose`) support a parameter that indicates which named
axes should be summed over in the backward pass if they were broadcasted
over in the forward pass. This enables use of these APIs in a
non-per-example way inside maps (initially only
{func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).
## jax 0.2.16 (June 23 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16).
## jax 0.2.15 (June 23 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15).
* New features:
* [#7042](https://github.com/google/jax/pull/7042) Turned on TFRT CPU backend
with significant dispatch performance improvements on CPU.
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans
({jax-issue}`#6956`).
* New SciPy function {py:func}`jax.scipy.special.lpmn_values`.
* Breaking changes:
* Support for NumPy 1.16 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
* Bug fixes:
* Fixed bug that prevented round-tripping from JAX to TF and back:
`jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`#6947`).
## jaxlib 0.1.68 (June 23 2021)
* Bug fixes:
* Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a
TPU buffer to CPU.
## jax 0.2.14 (June 10 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
* New features:
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
* A new configuration option `JAX_TRACEBACK_FILTERING` controls how JAX filters
tracebacks.
* A new traceback filtering mode using `__tracebackhide__` is now enabled by
default in sufficiently recent versions of IPython.
* The {func}`jax2tf.convert` supports shape polymorphism even when the
unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
({jax-issue}`#6827`).
* The {func}`jax2tf.convert` generates custom attributes with location information
in TF ops. The code that XLA generates after jax2tf
has the same location information as JAX/XLA.
* New SciPy function {py:func}`jax.scipy.special.lpmn`.
* Bug fixes:
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
for Python scalars and for choosing 32-bit vs. 64-bit computations
as JAX ({jax-issue}`#6883`).
* The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter
properly to apply only during the just-in-time conversion
({jax-issue}`#6720`).
* The {func}`jax2tf.convert` now converts `lax.dot_general` using the
`XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
({jax-issue}`#6717`).
* The {func}`jax2tf.convert` now has support for inequality comparisons and
min/max for complex numbers ({jax-issue}`#6892`).
## jaxlib 0.1.67 (May 17 2021)
## jaxlib 0.1.66 (May 11 2021)
* New features:
* CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.
NVIDIA now promises compatibility between CUDA minor releases starting with
CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that
is compatible with CUDA 11.2 and 11.3.
There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use
the CUDA 11.1 wheel for those versions (cuda111).
* Jaxlib now bundles `libdevice.10.bc` in CUDA wheels. There should be no need
to point JAX to a CUDA installation to find this file.
* Added automatic support for static keyword arguments to the {func}`jit`
implementation.
* Added support for pretransformation exception traces.
* Initial support for pruning unused arguments from {func}`jit`-transformed
computations.
Pruning is still a work in progress.
* Improved the string representation of {class}`PyTreeDef` objects.
* Added support for XLA's variadic ReduceWindow.
* Bug fixes:
* Fixed a bug in the remote cloud TPU support when large numbers of arguments
are passed to a computation.
* Fixed a bug that meant that JAX garbage collection was not triggered by
{func}`jit`-transformed functions.
## jax 0.2.13 (May 3 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13).
* New features:
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
keyword arguments as static.
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`#6501`)
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
* {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`).
* Added {func}`jax.scipy.linalg.eigh_tridiagonal` that computes the
eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at
present.
* The order of the filtered and unfiltered stack traces in exceptions has been
changed. The traceback attached to an exception thrown from JAX-transformed
code is now filtered, with an `UnfilteredStackTrace` exception
containing the original trace as the `__cause__` of the filtered exception.
Filtered stack traces now also work with Python 3.6.
* If an exception is thrown by code that has been transformed by reverse-mode
automatic differentiation, JAX now attempts to attach as a `__cause__` of
the exception a `JaxStackTraceBeforeTransformation` object that contains the
stack trace that created the original operation in the forward pass.
Requires jaxlib 0.1.66.
* Breaking changes:
* The following function names have changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed
so please change your code.
* `host_id` --> {func}`~jax.process_index`
* `host_count` --> {func}`~jax.process_count`
* `host_ids` --> `range(jax.process_count())`
* Similarly, the argument to {func}`~jax.local_devices` has been renamed from
`host_id` to `process_index`.
* Arguments to {func}`jax.jit` other than the function are now marked as
keyword-only. This change is to prevent accidental breakage when arguments
are added to `jit`.
* Bug fixes:
* The {func}`jax2tf.convert` now works in presence of gradients for functions
with integer inputs ({jax-issue}`#6360`).
* Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured
`tf.Variable` ({jax-issue}`#6572`).
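A minimal sketch of the `static_argnames` and `size` options noted under New features above (requires jaxlib 0.1.66 for `static_argnames`; the functions and arrays are illustrative):
```
from functools import partial
import jax
import jax.numpy as jnp

# static_argnames marks a keyword argument as a compile-time constant.
@partial(jax.jit, static_argnames='op')
def agg(x, op='sum'):
  return jnp.sum(x) if op == 'sum' else jnp.prod(x)

# size gives jnp.nonzero a static output shape, so it can be used under jit.
@jax.jit
def first_two_nonzero(x):
  return jnp.nonzero(x, size=2)[0]

x = jnp.array([0., 3., 0., 5.])
print(agg(x, op='prod'), first_two_nonzero(x))
```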
## jaxlib 0.1.65 (April 7 2021)
## jax 0.2.12 (April 1 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
* New features
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
* {func}`jax.lax.reduce` is now differentiable.
* Breaking changes:
* The minimum jaxlib version is now 0.1.64.
* Some profiler APIs names have been changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed
so please change your code.
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
* Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/design_notes/omnistaging.md)
for more information.
* Python integers larger than the maximum `int64` value will now lead to an overflow
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
* Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an
`OverflowError` rather than having their value silently truncated.
* Bug fixes:
* `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
* {func}`jax.random.randint` clips rather than wraps out-of-bounds limits, and can now generate
integers in the full range of the specified dtype ({jax-issue}`#5868`)
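A minimal sketch of the new profiling APIs noted above (the log directory is illustrative; traces can be viewed with TensorBoard's profiler plugin):
```
import jax
import jax.numpy as jnp

with jax.profiler.trace("/tmp/jax-trace"):
  x = jnp.ones((1000, 1000))
  (x @ x).block_until_ready()  # make sure the traced work actually runs
```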
## jax 0.2.11 (March 23 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11).
* New features:
* [#6112](https://github.com/google/jax/pull/6112) added context managers:
`jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
`jax.debug_infs`, `jax.log_compiles`.
* [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`
* Bug fixes:
* [#6136](https://github.com/google/jax/pull/6136) generalized
`jax.flatten_util.ravel_pytree` to handle integer dtypes.
* [#6129](https://github.com/google/jax/issues/6129) fixed a bug with handling
some constants like `enum.IntEnums`
* [#6145](https://github.com/google/jax/pull/6145) fixed batching issues with
incomplete beta functions
* [#6014](https://github.com/google/jax/pull/6014) fixed H2D transfers during
tracing
* [#6165](https://github.com/google/jax/pull/6165) avoids OverflowErrors when
converting some large Python integers to floats
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.
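A minimal sketch of the context managers added in #6112, assuming they take the new boolean value as an argument:
```
import jax
import jax.numpy as jnp

# log every compilation that happens inside the block
with jax.log_compiles(True):
  jax.jit(lambda x: x * 2)(jnp.arange(3))

# raise an error if a NaN is produced inside the block
with jax.debug_nans(True):
  jnp.log(jnp.array([1.0, 2.0]))
```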
## jaxlib 0.1.64 (March 18 2021)
## jaxlib 0.1.63 (March 17 2021)
## jax 0.2.10 (March 5 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10).
* New features:
* {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
* {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
* Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
from JAX ({jax-issue}`#5627` and
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
* Extended the batching rule for `lax.pad` to support batching of the padding values.
* Bug fixes:
* {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`)
* Breaking changes:
* JAX's promotion rules were adjusted to make promotion more consistent and
invariant to JIT. In particular, binary operations can now result in weakly-typed
values when appropriate. The main user-visible effect of the change is that
some operations result in outputs of different precision than before; for
example the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)`
previously returned a `float64` array, and now returns a `bfloat16` array.
JAX's type promotion behavior is described at {ref}`type-promotion`.
* {func}`jax.numpy.linspace` now computes the floor of integer values, i.e.,
rounding towards -inf rather than 0. This change was made to match NumPy
1.20.0.
* {func}`jax.numpy.i0` no longer accepts complex numbers. Previously the
function computed the absolute value of complex arguments. This change was
made to match the semantics of NumPy 1.20.0.
* Several {mod}`jax.numpy` functions no longer accept tuples or lists in place
of array arguments: {func}`jax.numpy.pad`, {func}`jax.numpy.ravel`,
{func}`jax.numpy.repeat`, {func}`jax.numpy.reshape`.
In general, {mod}`jax.numpy` functions should be used with scalars or array arguments.
## jaxlib 0.1.62 (March 9 2021)
* New features:
* jaxlib wheels are now built to require AVX instructions on x86-64 machines
by default. If you want to use JAX on a machine that doesn't support AVX,
you can build a jaxlib from source using the `--target_cpu_features` flag
to `build.py`. `--target_cpu_features` also replaces
`--enable_march_native`.
## jaxlib 0.1.61 (February 12 2021)
## jaxlib 0.1.60 (February 3 2021)
* Bug fixes:
* Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The
memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
* `bool`, `int8`, and `uint8` are now considered safe to cast to the
`bfloat16` NumPy extension type.
## jax 0.2.9 (January 26 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9).
* New features:
* Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved
error checking and error messages.
* Add {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`.
These are context managers which allow X64 mode to be temporarily enabled/disabled
within a session.
* Breaking changes:
* {func}`jax.ops.segment_sum` now drops segment IDs that are out of range rather
than wrapping them into the segment ID space. This was done for performance
reasons.
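A minimal sketch of the new X64 context managers noted under New features above:
```
import jax.numpy as jnp
from jax.experimental import enable_x64

with enable_x64():
  print(jnp.arange(3).dtype)   # int64 while the context is active
print(jnp.arange(3).dtype)     # back to the default (typically int32) outside
```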
## jaxlib 0.1.59 (January 15 2021)
## jax 0.2.8 (January 12 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8).
* New features:
* Add {func}`jax.closure_convert` for use with higher-order custom
derivative functions. ({jax-issue}`#5244`)
* Add {func}`jax.experimental.host_callback.call` to call a custom Python
function on the host and return a result to the device computation.
({jax-issue}`#5243`)
* Bug fixes:
* `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for
complex inputs ({jax-issue}`#5156`)
* `host_callback.id_tap` now works for `jax.pmap` also. There is an
optional parameter for `id_tap` and `id_print` to request that the
device from which the value is tapped be passed as a keyword argument
to the tap function ({jax-issue}`#5182`).
* Breaking changes:
* `jax.numpy.pad` now takes keyword arguments. Positional argument `constant_values`
has been removed. In addition, passing unsupported keyword arguments raises an error.
* Changes for {func}`jax.experimental.host_callback.id_tap` ({jax-issue}`#5243`):
* Removed support for `kwargs` for {func}`jax.experimental.host_callback.id_tap`.
(This support has been deprecated for a few months.)
* Changed the printing of tuples for {func}`jax.experimental.host_callback.id_print`
to use '(' instead of '['.
* Changed the {func}`jax.experimental.host_callback.id_print` in presence of JVP
to print a pair of primal and tangent. Previously, there were two separate
print operations for the primals and the tangent.
* `host_callback.outfeed_receiver` has been removed (it is not necessary,
and was deprecated a few months ago).
* New features:
* New flag for debugging `inf`, analogous to that for `NaN` ({jax-issue}`#5224`).
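A minimal sketch of {func}`jax.experimental.host_callback.call` noted above (the host function is illustrative):
```
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def host_log1p(x):  # runs on the host with a NumPy array
  return np.log1p(x)

@jax.jit
def f(x):
  return hcb.call(host_log1p, x,
                  result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))

print(f(jnp.arange(3.0)))
```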
## jax 0.2.7 (Dec 4 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7).
* New features:
* Add `jax.device_put_replicated`
* Add multi-host support to `jax.experimental.sharded_jit`
* Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
* Add support for building on Windows platforms
* Add support for general in_axes and out_axes in `jax.pmap`
* Add complex support for `jax.numpy.linalg.slogdet`
* Bug fixes:
* Fix higher-than-second order derivatives of `jax.numpy.sinc` at zero
* Fix some hard-to-hit bugs around symbolic zeros in transpose rules
* Breaking changes:
* `jax.experimental.optix` has been deleted, in favor of the standalone
`optax` Python package.
* indexing of JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing
has been deprecated in Numpy since v1.16, and in JAX since v0.2.4.
See {jax-issue}`#4564`.
## jax 0.2.6 (Nov 18 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New Features:
* Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter.
See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup
* Raise an error on non-hashable static arguments for jax.jit and
xla_computation. See [cb48f42](https://github.com/google/jax/commit/cb48f42).
* Improve consistency of type promotion behavior ({jax-issue}`#4744`):
* Adding a complex Python scalar to a JAX floating point number respects the precision of
the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
it returned `complex128`.
* Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type
are now independent of the order of arguments. For example:
`jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and
`jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously
the first returned `float64` and the second returned `float16`.
* The contents of the (undocumented) `jax.lax_linalg` linear algebra module
are now exposed publicly as `jax.lax.linalg`.
* `jax.random.PRNGKey` now produces the same results in and out of JIT compilation
({jax-issue}`#4877`).
This required changing the result for a given seed in a few particular cases:
* With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result
outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned
`[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
* Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError`
rather than a `TypeError`. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with `jax_enable_x64=False`
outside JIT, you can use:
```
key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
```
* DeviceArray now raises `RuntimeError` instead of `ValueError` when trying
to access its value while it has been deleted.
## jaxlib 0.1.58 (January 12ish 2021)
* Fixed a bug that meant JAX sometimes returned platform-specific types (e.g.,
`np.cint`) instead of standard types (e.g., `np.int32`). (#4903)
* Fixed a crash when constant-folding certain int16 operations. (#4971)
* Added an `is_leaf` predicate to {func}`pytree.flatten`.
## jaxlib 0.1.57 (November 12 2020)
* Fixed manylinux2010 compliance issues in GPU wheels.
* Switched the CPU FFT implementation from Eigen to PocketFFT.
* Fixed a bug where the hash of bfloat16 values was not correctly initialized
and could change (#4651).
* Add support for retaining ownership when passing arrays to DLPack (#4636).
* Fixed a bug for batched triangular solves with sizes greater than 128 but not
a multiple of 128.
* Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
* Fixed a bug in profiler where tools are missing (#4427).
* Dropped support for CUDA 10.0.
## jax 0.2.5 (October 27 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5).
* Improvements:
* Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`.
* Expanded the set of JAX primitives converted by jax2tf.
See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
## jax 0.2.4 (October 19 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4).
* Improvements:
* Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`.
* Deprecations
* Indexing with non-tuple sequences is now deprecated, following a similar deprecation in Numpy.
In a future release, this will result in a TypeError. See {jax-issue}`#4564`.
## jaxlib 0.1.56 (October 14, 2020)
## jax 0.2.3 (October 14 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3).
* The reason for another release so soon is that we need to temporarily roll back
a new jit fastpath while we look into a performance degradation.
## jax 0.2.2 (October 13 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.1...jax-v0.2.2).
## jax 0.2.1 (October 6 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.0...jax-v0.2.1).
* Improvements:
* As a benefit of omnistaging, the host_callback functions are executed (in program
order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/
{py:func}`jax.experimental.host_callback.id_tap` is not used in the computation.
## jax 0.2.0 (September 23 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.77...jax-v0.2.0).
* Improvements:
* Omnistaging on by default. See {jax-issue}`#3370` and
[omnistaging](https://github.com/google/jax/blob/main/design_notes/omnistaging.md)
## jax 0.1.77 (September 15 2020)
* Breaking changes:
* New simplified interface for {py:func}`jax.experimental.host_callback.id_tap` (#4101)
## jaxlib 0.1.55 (September 8, 2020)
* Update XLA:
* Fix bug in DLPackManagedTensorToBuffer (#4196)
## jax 0.1.76 (September 8, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.75...jax-v0.1.76).
## jax 0.1.75 (July 30, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.74...jax-v0.1.75).
* Bug Fixes:
* make jnp.abs() work for unsigned inputs (#3914)
* Improvements:
* "Omnistaging" behavior added behind a flag, disabled by default (#3370)
## jax 0.1.74 (July 29, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74).
* New Features:
* BFGS (#3101)
* TPU support for half-precision arithmetic (#3878)
* Bug Fixes:
* Prevent some accidental dtype warnings (#3874)
* Fix a multi-threading bug in custom derivatives (#3845, #3869)
* Improvements:
* Faster searchsorted implementation (#3873)
* Better test coverage for jax.numpy sorting algorithms (#3836)
## jaxlib 0.1.52 (July 22, 2020)
* Update XLA.
## jax 0.1.73 (July 22, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.72...jax-v0.1.73).
* The minimum jaxlib version is now 0.1.51.
* New Features:
* jax.image.resize. (#3703)
* hfft and ihfft (#3664)
* jax.numpy.intersect1d (#3726)
* jax.numpy.lexsort (#3812)
* `lax.scan` and the `scan` primitive support an `unroll`
parameter for loop unrolling when lowering to XLA
({jax-issue}`#3738`).
* Bug Fixes:
* Fix reduction repeated axis error (#3618)
* Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
* make psum transpose handle zero cotangents (#3653)
* Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
* Support differentiation through jax.lax.all_to_all (#3733)
* address nan issue in jax.scipy.special.zeta (#3777)
* Improvements:
* Many improvements to jax2tf
* Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
* Enable XLA SPMD partitioning by default. (#3151)
* Add support for 0d transpose convolution (#3643)
* Make LU gradient work for low-rank matrices (#3610)
* support multiple_results and custom JVPs in jet (#3657)
* Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
* Implement complex convolutions on CPU and GPU. (#3735)
* Make jnp.take work for empty slices of empty arrays. (#3751)
* Relax dimension ordering rules for dot_general. (#3778)
* Enable buffer donation for GPU. (#3800)
* Add support for base dilation and window dilation to reduce window op… (#3803)
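A minimal sketch of the new `unroll` parameter to `lax.scan` ({jax-issue}`#3738`); the scanned function is illustrative:
```
import jax.numpy as jnp
from jax import lax

def running_sum(xs):
  def step(carry, x):
    carry = carry + x
    return carry, carry
  # unroll=4 asks the lowering to unroll four scan iterations per XLA While step
  _, ys = lax.scan(step, jnp.float32(0.), xs, unroll=4)
  return ys

print(running_sum(jnp.arange(8, dtype=jnp.float32)))
```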
## jaxlib 0.1.51 (July 2, 2020)
* Update XLA.
* Add new runtime support for host_callback.
## jax 0.1.72 (June 28, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.71...jax-v0.1.72).
* Bug fixes:
* Fix an odeint bug introduced in the previous release, see
{jax-issue}`#3587`.
## jax 0.1.71 (June 25, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.70...jax-v0.1.71).
* The minimum jaxlib version is now 0.1.48.
* Bug fixes:
* Allow `jax.experimental.ode.odeint` dynamics functions to close over
values with respect to which we're differentiating
{jax-issue}`#3562`.
## jaxlib 0.1.50 (June 25, 2020)
* Add support for CUDA 11.0.
* Drop support for CUDA 9.2 (we only maintain support for the last four CUDA
versions.)
* Update XLA.
## jaxlib 0.1.49 (June 19, 2020)
* Bug fixes:
* Fix build issue that could result in slow compiles
(<https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1>)
## jaxlib 0.1.48 (June 12, 2020)
* New features:
* Adds support for fast traceback collection.
* Adds preliminary support for on-device heap profiling.
* Implements `np.nextafter` for `bfloat16` types.
* Complex128 support for FFTs on CPU and GPU.
* Bugfixes:
* Improved float64 `tanh` accuracy on GPU.
* float64 scatters on GPU are much faster.
* Complex matrix multiplication on CPU should be much faster.
* Stable sorts on CPU should actually be stable now.
* Concurrency bug fix in CPU backend.
## jax 0.1.70 (June 8, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70).
* New features:
* `lax.switch` introduces indexed conditionals with multiple
branches, together with a generalization of the `cond`
primitive
{jax-issue}`#3318`.
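A minimal sketch of `lax.switch` (the branches and operand are illustrative):
```
import jax.numpy as jnp
from jax import lax

branches = [lambda x: x - 1., lambda x: x * 0., lambda x: x + 1.]
# the index is clamped into [0, len(branches) - 1]
print(lax.switch(2, branches, jnp.float32(10.)))  # applies branches[2] -> 11.0
```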
## jax 0.1.69 (June 3, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.68...jax-v0.1.69).
## jax 0.1.68 (May 21, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.67...jax-v0.1.68).
* New features:
* {func}`lax.cond` supports a single-operand form, taken as the argument
to both branches
{jax-issue}`#2993`.
* Notable changes:
* The format of the `transforms` keyword for the {func}`jax.experimental.host_callback.id_tap`
primitive has changed {jax-issue}`#3132`.
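A minimal sketch of the single-operand form of {func}`lax.cond` described above:
```
import jax.numpy as jnp
from jax import lax

x = jnp.float32(3.)
# the same operand is passed to both branches
y = lax.cond(x > 0, lambda v: v * 2., lambda v: -v, x)
print(y)  # 6.0
```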
## jax 0.1.67 (May 12, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.66...jax-v0.1.67).
* New features:
* Support for reduction over subsets of a pmapped axis using `axis_index_groups`
{jax-issue}`#2382`.
* Experimental support for printing and calling host-side Python function from
compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html)
({jax-issue}`#3006`).
* Notable changes:
* The visibility of names exported from {mod}`jax.numpy` has been
tightened. This may break code that was making use of names that were
previously exported accidentally.
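A minimal sketch of reducing over subsets of a pmapped axis with `axis_index_groups`; it assumes 4 local devices and the grouping is illustrative:
```
import jax
import jax.numpy as jnp

# psum reduces within each listed group of device indices
f = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]),
             axis_name='i')
print(f(jnp.arange(4.)))  # [1. 1. 5. 5.]
```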
## jaxlib 0.1.47 (May 8, 2020)
* Fixes crash for outfeed.
## jax 0.1.66 (May 5, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.65...jax-v0.1.66).
* New features:
* Support for `in_axes=None` on {func}`pmap`
{jax-issue}`#2896`.
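A minimal sketch of `in_axes=None` on {func}`pmap`, which broadcasts an argument to every device instead of splitting it:
```
import jax
import jax.numpy as jnp

f = jax.pmap(lambda x, w: x * w, in_axes=(0, None))
n = jax.local_device_count()
print(f(jnp.arange(float(n)), jnp.float32(2.)))
```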
## jaxlib 0.1.46 (May 5, 2020)
* Fixes crash for linear algebra functions on Mac OS X (#432).
* Fixes an illegal instruction crash caused by using AVX512 instructions when
an operating system or hypervisor disabled them (#2906).
## jax 0.1.65 (April 30, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.64...jax-v0.1.65).
* New features:
* Differentiation of determinants of singular matrices
{jax-issue}`#2809`.
* Bug fixes:
* Fix {func}`odeint` differentiation with respect to time of ODEs with
time-dependent dynamics {jax-issue}`#2817`,
also add ODE CI testing.
* Fix {func}`lax_linalg.qr` differentiation
{jax-issue}`#2867`.
## jaxlib 0.1.45 (April 21, 2020)
* Fixes segfault: {jax-issue}`#2755`
* Plumb is_stable option on Sort HLO through to Python.
## jax 0.1.64 (April 21, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64).
* New features:
* Add syntactic sugar for functional indexed updates
{jax-issue}`#2684`.
* Add {func}`jax.numpy.linalg.multi_dot` {jax-issue}`#2726`.
* Add {func}`jax.numpy.unique` {jax-issue}`#2760`.
* Add {func}`jax.numpy.rint` {jax-issue}`#2724`.
* Add more primitive rules for {func}`jax.experimental.jet`.
* Bug fixes:
* Fix {func}`logaddexp` and {func}`logaddexp2` differentiation at zero {jax-issue}`#2107`.
* Improve memory usage in reverse-mode autodiff without {func}`jit`
{jax-issue}`#2719`.
* Better errors:
* Improves error message for reverse-mode differentiation of {func}`lax.while_loop`
{jax-issue}`#2129`.
## jaxlib 0.1.44 (April 16, 2020)
* Fixes a bug where if multiple GPUs of different models were present, JAX
would only compile programs suitable for the first GPU.
* Bugfix for `batch_group_count` convolutions.
* Added precompiled SASS for more GPU versions to avoid startup PTX compilation
hang.
## jax 0.1.63 (April 12, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63).
* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
* Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
* Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
* Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`.
* Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
* Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
* Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`.
* Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
* Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
* Add `callback_transform` {jax-issue}`#2665`.
* Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`,
`trunc`, `roots`, and `quantile`/`percentile` interpolation options.
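A minimal sketch of `jax.custom_jvp` from the first item above; the function is the `log1pexp` example used in the linked tutorial:
```
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  # d/dx log(1 + e^x) is the sigmoid 1 / (1 + e^-x)
  return ans, x_dot / (1. + jnp.exp(-x))

print(jax.grad(log1pexp)(10.))  # ~0.99995
```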
## jaxlib 0.1.43 (March 31, 2020)
* Fixed a performance regression for Resnet-50 on GPU.
## jax 0.1.62 (March 21, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62).
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
* Removed the internal function `lax._safe_mul`, which implemented the
convention `0. * nan == 0.`. This change means some programs when
differentiated will produce nans when they previously produced correct
values, though it ensures nans rather than silently incorrect results are
produced for other programs. See #2447 and #1052 for details.
* Added an `all_gather` parallel convenience function.
* More type annotations in core code.
## jaxlib 0.1.42 (March 19, 2020)
* jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This
release fixes it again.
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
## jax 0.1.61 (March 17, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61).
* Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
supports Python 3.5.
## jax 0.1.60 (March 17, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60).
* New features:
* {py:func}`jax.pmap` has a `static_broadcasted_argnums` argument which allows
the user to specify arguments that should be treated as compile-time
constants and should be broadcasted to all devices. It works analogously to
`static_argnums` in {py:func}`jax.jit`.
* Improved error messages for when tracers are mistakenly saved in global state.
* Added {py:func}`jax.nn.one_hot` utility function.
* Added {mod}`jax.experimental.jet` for exponentially faster
higher-order automatic differentiation.
* Added more correctness checking to arguments of {py:func}`jax.lax.broadcast_in_dim`.
* The minimum jaxlib version is now 0.1.41.
## jaxlib 0.1.40 (March 4, 2020)
* Adds experimental support in Jaxlib for TensorFlow profiler, which allows
tracing of CPU and GPU computations from TensorBoard.
* Includes prototype support for multihost GPU computations that communicate via
NCCL.
* Improves performance of NCCL collectives on GPU.
* Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and
RandomGamma implementations.
* Supports device assignments known at XLA compilation time.
## jax 0.1.59 (February 11, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59).
* Breaking changes
* The minimum jaxlib version is now 0.1.38.
* Simplified {py:class}`Jaxpr` by removing the `Jaxpr.freevars` and
`Jaxpr.bound_subjaxprs`. The call primitives (`xla_call`, `xla_pmap`,
`sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a
fully-closed (no `constvars`) jaxpr. Also, added a new field `call_primitive`
to primitives.
* New features:
* Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it
now differentiable in both modes ({jax-issue}`#2091`)
* JAX now supports DLPack, which allows sharing CPU and GPU arrays in a
zero-copy way with other libraries, such as PyTorch.
* JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another
zero-copy protocol for sharing GPU arrays with other libraries such as CuPy
and Numba.
* JAX CPU device buffers now implement the Python buffer protocol, which allows
zero-copy buffer sharing between JAX and NumPy.
* Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
## jaxlib 0.1.39 (February 11, 2020)
* Updates XLA.
## jaxlib 0.1.38 (January 29, 2020)
* CUDA 9.0 is no longer supported.
* CUDA 10.2 wheels are now built by default.
## jax 0.1.58 (January 28, 2020)
* [GitHub commits](https://github.com/google/jax/compare/46014da21...jax-v0.1.58).
* Breaking changes
* JAX has dropped Python 2 support, because Python 2 reached its end of life on
January 1, 2020. Please update to Python 3.5 or newer.
* New features
  * Forward-mode automatic differentiation (`jvp`) of while loop
    ({jax-issue}`#1980`)
  * New NumPy and SciPy functions:
    * {py:func}`jax.numpy.fft.fft2`
    * {py:func}`jax.numpy.fft.ifft2`
    * {py:func}`jax.numpy.fft.rfft`
    * {py:func}`jax.numpy.fft.irfft`
    * {py:func}`jax.numpy.fft.rfft2`
    * {py:func}`jax.numpy.fft.irfft2`
    * {py:func}`jax.numpy.fft.rfftn`
    * {py:func}`jax.numpy.fft.irfftn`
    * {py:func}`jax.numpy.fft.fftfreq`
    * {py:func}`jax.numpy.fft.rfftfreq`
    * {py:func}`jax.numpy.linalg.matrix_rank`
    * {py:func}`jax.numpy.linalg.matrix_power`
    * {py:func}`jax.scipy.special.betainc`
  * Batched Cholesky decomposition on GPU now uses a more efficient batched
    kernel.
* Notable bug fixes:
* With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should
help with installation.