Change Log
==========

.. This is a comment.
   Remember to leave an empty line before the start of an itemized list,
   and to align the itemized text with the first line of an item.

.. PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.

These are the release notes for JAX.

jax 0.1.65 (unreleased)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.64...master>`_.

jaxlib 0.1.46 (unreleased)
------------------------------

jaxlib 0.1.45 (April 21, 2020)
------------------------------

* Fixes a segfault: https://github.com/google/jax/issues/2755
* Plumbs the ``is_stable`` option on the Sort HLO through to Python.

jax 0.1.64 (April 21, 2020)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64>`_.
* New features:

  * Add syntactic sugar for functional indexed updates
    `#2684 <https://github.com/google/jax/issues/2684>`_.
  * Add :func:`jax.numpy.linalg.multi_dot` `#2726 <https://github.com/google/jax/issues/2726>`_.
  * Add :func:`jax.numpy.unique` `#2760 <https://github.com/google/jax/issues/2760>`_.
  * Add :func:`jax.numpy.rint` `#2724 <https://github.com/google/jax/issues/2724>`_.
  * Add more primitive rules for :func:`jax.experimental.jet`.

* Bug fixes:

  * Fix :func:`logaddexp` and :func:`logaddexp2` differentiation at zero
    `#2107 <https://github.com/google/jax/issues/2107>`_.
  * Improve memory usage in reverse-mode autodiff without :func:`jit`
    `#2719 <https://github.com/google/jax/issues/2719>`_.

* Better errors:

  * Improve the error message for reverse-mode differentiation of :func:`lax.while_loop`
    `#2129 <https://github.com/google/jax/issues/2129>`_.
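
The following is a minimal sketch of three of the additions listed above, written against a recent JAX release rather than 0.1.64 itself. It assumes that the indexed-update sugar from #2684 is the ``x.at[idx].set(value)`` form found in current JAX; the array values are illustrative only.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Functional indexed update: returns a new array and leaves `x` unchanged.
y = x.at[1:3].set(7.0)

# rint rounds to the nearest integer, with halves going to the nearest even value.
rounded = jnp.rint(jnp.array([0.5, 1.5, 2.4]))

# multi_dot computes a chain of matrix products in a single call.
a, b, c = jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones((4, 1))
product = jnp.linalg.multi_dot([a, b, c])

print(y)              # [0. 7. 7. 0. 0.]
print(rounded)        # [0. 2. 2.]
print(product.shape)  # (2, 1)
```

Because JAX arrays are immutable, the update produces a new array instead of modifying ``x`` in place.
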

jaxlib 0.1.44 (April 16, 2020)
------------------------------

* Fixes a bug where if multiple GPUs of different models were present, JAX
  would only compile programs suitable for the first GPU.
* Bugfix for ``batch_group_count`` convolutions.
* Added precompiled SASS for more GPU versions to avoid a hang during startup
  PTX compilation.

jax 0.1.63 (April 12, 2020)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63>`_.
* Added ``jax.custom_jvp`` and ``jax.custom_vjp`` from `#2026 <https://github.com/google/jax/pull/2026>`_, see the `tutorial notebook <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_. Deprecated ``jax.custom_transforms`` and removed it from the docs (though it still works).
* Add ``scipy.sparse.linalg.cg`` `#2566 <https://github.com/google/jax/pull/2566>`_.
* Changed how Tracers are printed to show more useful information for debugging `#2591 <https://github.com/google/jax/pull/2591>`_.
* Made ``jax.numpy.isclose`` handle ``nan`` and ``inf`` correctly `#2501 <https://github.com/google/jax/pull/2501>`_.
* Added several new rules for ``jax.experimental.jet`` `#2537 <https://github.com/google/jax/pull/2537>`_.
* Fixed ``jax.experimental.stax.BatchNorm`` when ``scale``/``center`` isn't provided.
* Fix some missing cases of broadcasting in ``jax.numpy.einsum`` `#2512 <https://github.com/google/jax/pull/2512>`_.
* Implement ``jax.numpy.cumsum`` and ``jax.numpy.cumprod`` in terms of a parallel prefix scan `#2596 <https://github.com/google/jax/pull/2596>`_ and make ``reduce_prod`` differentiable to arbitrary order `#2597 <https://github.com/google/jax/pull/2597>`_.
* Add ``batch_group_count`` to ``conv_general_dilated`` `#2635 <https://github.com/google/jax/pull/2635>`_.
* Add docstring for ``test_util.check_grads`` `#2656 <https://github.com/google/jax/pull/2656>`_.
* Add ``callback_transform`` `#2665 <https://github.com/google/jax/pull/2665>`_.
* Implement ``rollaxis``, ``convolve``/``correlate`` 1d & 2d, ``copysign``,
  ``trunc``, ``roots``, and ``quantile``/``percentile`` interpolation options.
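
As a hedged illustration of ``jax.custom_jvp``, the sketch below attaches a hand-written derivative rule to a function. The function ``log1pexp`` and its rule are illustrative choices that do not come from these release notes, and the code targets a recent JAX release.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    # Plain implementation of log(1 + e^x).
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Hand-written rule: the derivative is sigmoid(x), written so that it
    # stays finite even when exp(x) overflows to inf.
    return log1pexp(x), t * (1.0 - 1.0 / (1.0 + jnp.exp(x)))


grad_at_zero = jax.grad(log1pexp)(0.0)
grad_at_large = jax.grad(log1pexp)(100.0)
print(grad_at_zero)   # 0.5
print(grad_at_large)  # 1.0
```

``jax.grad`` uses the custom rule in place of differentiating the function body, which is why the gradient at a large input is finite here.
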

jaxlib 0.1.43 (March 31, 2020)
------------------------------

* Fixed a performance regression for ResNet-50 on GPU.

jax 0.1.62 (March 21, 2020)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62>`_.
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
* Removed the internal function ``lax._safe_mul``, which implemented the
  convention ``0. * nan == 0.``. This change means some programs will produce
  NaNs when differentiated where they previously produced correct values,
  though it ensures that other programs produce NaNs rather than silently
  incorrect results. See #2447 and #1052 for details.
* Added an ``all_gather`` parallel convenience function.
* More type annotations in core code.

jaxlib 0.1.42 (March 19, 2020)
------------------------------

* jaxlib 0.1.41 broke Cloud TPU support due to an API incompatibility. This
  release fixes it again.
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.

jax 0.1.61 (March 17, 2020)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61>`_.
* Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
  supports Python 3.5.

jax 0.1.60 (March 17, 2020)
---------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60>`_.
* New features:

  * :py:func:`jax.pmap` has a ``static_broadcast_argnums`` argument which allows
    the user to specify arguments that should be treated as compile-time
    constants and should be broadcast to all devices. It works analogously to
    ``static_argnums`` in :py:func:`jax.jit`.
  * Improved error messages for when tracers are mistakenly saved in global state.
  * Added :py:func:`jax.nn.one_hot` utility function.
  * Added :py:mod:`jax.experimental.jet` for exponentially faster
    higher-order automatic differentiation.
  * Added more sanity checking to arguments of :py:func:`jax.lax.broadcast_in_dim`.

* The minimum jaxlib version is now 0.1.41.
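
A short sketch of two of the items above, under stated assumptions. It shows ``static_argnums`` in ``jax.jit``, which the notes describe as the analogue of the new ``pmap`` argument (the ``pmap`` argument itself is not exercised here, since that needs a multi-device setup), followed by ``jax.nn.one_hot``. The helper ``double_n_times`` is hypothetical, and the code targets a recent JAX release.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def double_n_times(x, n):
    # `n` is a compile-time constant, so it can drive a plain Python loop.
    for _ in range(n):
        x = x * 2
    return x


doubled = double_n_times(jnp.ones(2), 3)

# one_hot turns integer class labels into rows of indicator values.
labels = jnp.array([0, 2, 1])
encoded = jax.nn.one_hot(labels, 3)

print(doubled)  # [8. 8.]
print(encoded)
```

Calling ``double_n_times`` with a different value of ``n`` triggers a fresh compilation, which is the usual trade-off of marking an argument as static.
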

jaxlib 0.1.40 (March 4, 2020)
-------------------------------

* Adds experimental support in jaxlib for the TensorFlow profiler, which allows
  tracing of CPU and GPU computations from TensorBoard.
* Includes prototype support for multihost GPU computations that communicate via
  NCCL.
* Improves performance of NCCL collectives on GPU.
* Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and
  RandomGamma implementations.
* Supports device assignments known at XLA compilation time.

jax 0.1.59 (February 11, 2020)
------------------------------

* `GitHub commits <https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59>`_.
* Breaking changes

  * The minimum jaxlib version is now 0.1.38.
  * Simplified :py:class:`Jaxpr` by removing ``Jaxpr.freevars`` and
    ``Jaxpr.bound_subjaxprs``. The call primitives (``xla_call``, ``xla_pmap``,
    ``sharded_call``, and ``remat_call``) get a new parameter ``call_jaxpr`` with a
    fully-closed (no ``constvars``) jaxpr. Also, added a new field ``call_primitive``
    to primitives.

* New features:

  * Reverse-mode automatic differentiation (e.g. ``grad``) of ``lax.cond``, making it
    differentiable in both modes (https://github.com/google/jax/pull/2091).
  * JAX now supports DLPack, which allows sharing CPU and GPU arrays in a
    zero-copy way with other libraries, such as PyTorch.
  * JAX GPU DeviceArrays now support ``__cuda_array_interface__``, which is another
    zero-copy protocol for sharing GPU arrays with other libraries such as CuPy
    and Numba.
  * JAX CPU device buffers now implement the Python buffer protocol, which allows
    zero-copy buffer sharing between JAX and NumPy.
  * Added the ``JAX_SKIP_SLOW_TESTS`` environment variable to skip tests known to be slow.
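
To illustrate the reverse-mode support for ``lax.cond`` noted above, here is a minimal sketch. It assumes the ``lax.cond(pred, true_fun, false_fun, *operands)`` calling convention of recent JAX releases, which may differ from the signature in 0.1.59; the function ``piecewise`` is a made-up example.

```python
import jax
from jax import lax


def piecewise(x):
    # x**2 for positive inputs, -x otherwise.
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)


grad_pos = jax.grad(piecewise)(3.0)
grad_neg = jax.grad(piecewise)(-1.0)
print(grad_pos)  # 6.0
print(grad_neg)  # -1.0
```

The gradient follows whichever branch the predicate selects for the given input.
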

jaxlib 0.1.39 (February 11, 2020)
---------------------------------

* Updates XLA.

jaxlib 0.1.38 (January 29, 2020)
--------------------------------

* CUDA 9.0 is no longer supported.
* CUDA 10.2 wheels are now built by default.

jax 0.1.58 (January 28, 2020)
-----------------------------

* `GitHub commits <https://github.com/google/jax/compare/46014da21...jax-v0.1.58>`_.
* Breaking changes

  * JAX has dropped Python 2 support, because Python 2 reached its end of life on
    January 1, 2020. Please update to Python 3.5 or newer.

* New features

  * Forward-mode automatic differentiation (``jvp``) of while loop
    (https://github.com/google/jax/pull/1980)
  * New NumPy and SciPy functions:

    * :py:func:`jax.numpy.fft.fft2`
    * :py:func:`jax.numpy.fft.ifft2`
    * :py:func:`jax.numpy.fft.rfft`
    * :py:func:`jax.numpy.fft.irfft`
    * :py:func:`jax.numpy.fft.rfft2`
    * :py:func:`jax.numpy.fft.irfft2`
    * :py:func:`jax.numpy.fft.rfftn`
    * :py:func:`jax.numpy.fft.irfftn`
    * :py:func:`jax.numpy.fft.fftfreq`
    * :py:func:`jax.numpy.fft.rfftfreq`
    * :py:func:`jax.numpy.linalg.matrix_rank`
    * :py:func:`jax.numpy.linalg.matrix_power`
    * :py:func:`jax.scipy.special.betainc`

  * Batched Cholesky decomposition on GPU now uses a more efficient batched
    kernel.

Notable bug fixes
^^^^^^^^^^^^^^^^^

* With the Python 3 upgrade, JAX no longer depends on ``fastcache``, which should
  help with installation.
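
The forward-mode support for while loops listed under the new features of this release can be sketched as follows. This is an illustrative example against a recent JAX release, not code from the release itself; the function ``cube`` is hypothetical and computes ``x ** 3`` by repeated multiplication inside ``lax.while_loop``.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cube(x):
    # Multiply an accumulator by x three times inside a while loop.
    def cond_fun(carry):
        i, _ = carry
        return i < 3

    def body_fun(carry):
        i, acc = carry
        return i + 1, acc * x

    _, out = lax.while_loop(cond_fun, body_fun, (0, jnp.ones_like(x)))
    return out


# jvp returns the function value and the directional derivative together.
primal, tangent = jax.jvp(cube, (jnp.float32(2.0),), (jnp.float32(1.0),))
print(primal, tangent)  # 8.0 12.0
```
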