# Change log
Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
For the changes specific to the experimental Pallas APIs,
see {ref}`pallas-changelog`.
JAX follows Effort-based versioning; for a discussion of this and JAX's API
compatibility policy, refer to {ref}`api-compatibility`. For the Python and
NumPy version support policy, refer to {ref}`version-support-policy`.
<!--
Remember to align the itemized text with the first line of an item within a list.
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->
## Unreleased
* New Features
* Added an `allow_negative_indices` option to {func}`jax.lax.dynamic_slice`,
  {func}`jax.lax.dynamic_update_slice` and related functions. The default is
  true, matching the current behavior. If set to false, JAX does not need to
  emit code that clamps negative indices, which improves code size; see the
  sketch after this list.
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
without replacement.
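
A minimal sketch of both options above (these are unreleased APIs, so the exact keyword names and defaults may still change):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Sampling without replacement (assumes the new `replace` kwarg lands
# on jax.random.categorical as described above).
logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
samples = jax.random.categorical(key, logits, shape=(3,), replace=False)

# Slicing with negative-index clamping disabled (assumes the new
# `allow_negative_indices` kwarg); indices must then be non-negative.
x = jnp.arange(8)
y = jax.lax.dynamic_slice(x, (2,), (3,), allow_negative_indices=False)
```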
## jax 0.5.2 (Mar 4, 2025)
Patch release of 0.5.1
* Bug fixes
* Fixes TPU metric logging and `tpu-info`, which were broken in 0.5.1
## jax 0.5.1 (Feb 24, 2025)
* New Features
* Added an experimental {func}`jax.experimental.custom_dce.custom_dce`
decorator to support customizing the behavior of opaque functions under
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
details.
* Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
  {func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
  {func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and
  {func}`jax.lax.reduce_xor` (see the sketch after this list).
* {func}`jax.lax.linalg.qr` and {func}`jax.scipy.linalg.qr` now support
  column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
  {jax-issue}`#25955` for more details.
* Added {func}`jax.random.multinomial`.
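
A minimal sketch of the new reduction APIs, assuming they accept an explicit `axes` tuple like the underlying reduction primitives (the exact signature is an assumption here):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
# Reduce over axis 0, producing (4,)-shaped results. Unlike jnp.sum,
# these functions map directly onto the low-level reduction primitives.
col_sums = jax.lax.reduce_sum(x, axes=(0,))   # assumed `axes` kwarg
col_maxes = jax.lax.reduce_max(x, axes=(0,))
```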
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
  env vars. Previously they could only be specified via `jax.config` or flags.
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
multi-process CPU communication works out-of-the-box.
* The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
This package may safely be removed if it is present on your machine; JAX now
uses `libtpu` instead.
* Deprecations
* The internal function `linear_util.wrap_init` and the constructor
  `core.Jaxpr` must now take a non-empty `core.DebugInfo` kwarg. For
  a limited time, a `DeprecationWarning` is printed if
  `jax.extend.linear_util.wrap_init` is used without debugging info.
  A downstream effect of this is that several other internal functions need
  debug info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.
* In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`,
non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
* Bug fixes
* TPU runtime startup and shutdown time should be significantly improved on
TPU v5e and newer (from around 17s to around 8s). If not already set, you may
need to enable transparent hugepages in your VM image
(`sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`).
We hope to improve this further in future releases.
* The persistent compilation cache no longer writes an access-time file if
  `JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, i.e. if the LRU
  eviction policy isn't enabled. This should improve performance when using
  the cache with large-scale network storage.
## jax 0.5.0 (Jan 17, 2025)
As of this release, JAX now uses
[effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html).
Since this release makes a breaking change to PRNG key semantics that
may require users to update their code, we are bumping the "meso" version of JAX
to signify this.
* Breaking changes
* Enable `jax_threefry_partitionable` by default (see
[the update note](https://github.com/jax-ml/jax/discussions/18480)).
* This release drops support for Mac x86 wheels. Mac ARM of course remains
supported. For a recent discussion, see
https://github.com/jax-ml/jax/discussions/22936.
Two key factors motivated this decision:
* The Mac x86 build (only) has a number of test failures and crashes. We
would prefer to ship no release than a broken release.
* Mac x86 hardware is end-of-life and cannot be easily obtained for
developers at this point. So it is difficult for us to fix this kind of
problem even if we wanted to.
  We are open to re-adding support for Mac x86 if the community is willing
to help support that platform: in particular, we would need the JAX test
suite to pass cleanly on Mac x86 before we could ship releases again.
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
supported version until June 2025.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`#25214`).
* {func}`jax.numpy.linalg.solve` no longer supports batched 1D arguments
on the right hand side. To recover the previous behavior in these cases,
use `solve(a, b[..., None]).squeeze(-1)`.
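
For example, to solve a batch of 1D right-hand sides under the new behavior:

```python
import jax.numpy as jnp

a = jnp.broadcast_to(jnp.eye(3), (4, 3, 3))  # batch of 4 (3, 3) matrices
b = jnp.ones((4, 3))                         # batch of 4 length-3 vectors
# Add a trailing unit dimension, solve as batched 2D, then squeeze it away.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)
```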
* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
  {func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
  transforms in more than 3 dimensions, which was previously the limit. See
  {jax-issue}`#25606` for more details, and the example after this list.
* Support added for user defined state in the FFI via the new
{func}`jax.ffi.register_ffi_type_id` function.
* The AOT lowering `.as_text()` method now supports the `debug_info` option
to include debugging information, e.g., source location, in the output.
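
For instance, a 4-dimensional transform, which previously exceeded the 3-dimension limit, now works:

```python
import jax.numpy as jnp

x = jnp.ones((2, 3, 4, 5))
# Transform over all four axes; previously at most 3 were supported.
X = jnp.fft.fftn(x, axes=(0, 1, 2, 3))
assert X.shape == x.shape
```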
* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.
* The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the
previous import path is deprecated.
* Deletions
* The `jax_enable_memories` flag has been deleted; the behavior it enabled
  is now on by default.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.
* The `jax.experimental.array_api` module has been removed after being
deprecated in JAX v0.4.32. Since that release, {mod}`jax.numpy` supports
the array API directly.
## jax 0.4.38 (Dec 17, 2024)
* Breaking Changes
* `XlaExecutable.cost_analysis` now returns a `dict[str, float]` (instead of a
single-element `list[dict[str, float]]`).
* Changes:
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
as shortcuts of the corresponding `tree_util` functions.
* Deprecations
* A number of APIs in the internal `jax.core` namespace have been deprecated.
Most were no-ops, were little-used, or can be replaced by APIs of the same
name in {mod}`jax.extend.core`; see the documentation for {mod}`jax.extend`
for information on the compatibility guarantees of these semi-public extensions.
* Several previously-deprecated APIs have been removed, including:
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
`non_negative_dim`.
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
* from {mod}`jax.numpy`: `round_`.
* New Features
* {func}`jax.export.export` can be used for device-polymorphic export with
  shardings constructed with {class}`jax.sharding.AbstractMesh`.
See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
* Added {func}`jax.lax.split`. This is a primitive version of
{func}`jax.numpy.split`, added because it yields a more compact
transpose during automatic differentiation.
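
A short sketch of {func}`jax.lax.split`, assuming a `split(operand, sizes, axis=0)` signature analogous to {func}`jax.numpy.split` with explicit section sizes:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12.0)
# Sizes must sum to the length of the split axis.
a, b, c = jax.lax.split(x, (3, 4, 5))  # shapes (3,), (4,), (5,)
```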
## jax 0.4.37 (Dec 9, 2024)
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
* Bug fixes
* Fixed a bug where `jit` would error if an argument was named `f` (#25329).
* Fixed a bug that raised an `index out of range` error in
  {func}`jax.lax.while_loop` if the user registered a pytree node class with
  different aux data for `flatten` and `flatten_with_path`.
* Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
## jax 0.4.36 (Dec 5, 2024)
* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing
machinery. We made trace dispatch purely a function of context rather than a
function of both context and data. This let us delete a lot of machinery for
managing data-dependent tracing: levels, sublevels, `post_process_call`,
`new_base_main`, `custom_bind`, and so on. The change should only affect
users that use JAX internals.
If you do use JAX internals then you may need to
update your code (see
https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
for clues about how to do this). There might also be version skew
issues with JAX libraries that do this. If you find this change breaks your
non-JAX-internals-using code then try the
`config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
you need help updating your code then please file a bug.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
  or with `enable_xla=False` had been deprecated since July 2024 (JAX
  version 0.4.31); support for these use cases has now been removed. `jax2tf`
  with native serialization will still be supported.
* In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed
after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`,
`xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
* The deprecated module `jax.experimental.export` has been removed. It was replaced
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
for information on migrating to the new API.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
has been removed, after being deprecated in v0.4.27.
* Calling `np.asarray` on typed PRNG keys (i.e. keys produced by {func}`jax.random.key`)
  now raises an error. Previously, this returned a scalar object array.
* The following deprecated methods and functions in {mod}`jax.export` have
been removed:
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
already.
* `jax.export.Exported.lowering_platforms`: use `platforms`.
* `jax.export.Exported.mlir_module_serialization_version`:
use `calling_convention_version`.
* `jax.export.Exported.uses_shape_polymorphism`:
use `uses_global_constants`.
* the `lowering_platforms` kwarg for {func}`jax.export.export`: use
`platforms` instead.
* The kwargs `symbolic_scope` and `symbolic_constraints` from
{func}`jax.export.symbolic_args_specs` have been removed. They were
deprecated in June 2024. Use `scope` and `constraints` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`.
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
replaces previous build.py usage. Run `python build/build.py --help` for
more details. Brief overview of the new subcommand options:
  * `build`: Builds JAX wheel packages, e.g. `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`
* `requirements_update`: Updates requirements_lock.txt files.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
return NaN for negative integer inputs, to match the behavior of SciPy from
https://github.com/scipy/scipy/pull/21827.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
* We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
  calls whose export stability we guarantee. This is because this custom call
  relies on Triton IR, which is not guaranteed to be stable. If you need
  to export code that uses this custom call, you can use the `disabled_checks`
  parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).
* New Features
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
passing compilation options to XLA. For the moment it's undocumented and
may be in flux.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
  declared inline via {func}`dataclasses.field`. See the function documentation
  for examples, and the sketch after this list.
* Added {func}`jax.numpy.put_along_axis`.
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
supported on GPU. See {jax-issue}`#24663` for more details.
* Added two new configuration flags, `jax_exec_time_optimization_effort` and
  `jax_memory_fitting_effort`, to control the amount of effort the compiler
  spends minimizing execution time and memory usage, respectively. Valid
  values are between -1.0 and 1.0; the default is 0.0.
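
A sketch of the inline-metadata style for {func}`jax.tree_util.register_dataclass`, assuming the `static` metadata key marks a field as metadata rather than array data:

```python
import dataclasses
import jax
import jax.numpy as jnp

@jax.tree_util.register_dataclass
@dataclasses.dataclass
class Params:
    weights: jax.Array
    # Declared inline as a metadata (static) field, instead of passing
    # meta_fields=... to register_dataclass explicitly.
    name: str = dataclasses.field(metadata=dict(static=True))

p = Params(weights=jnp.ones(2), name="layer0")
print(jax.tree.leaves(p))  # only `weights` is a pytree leaf
```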
* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}`#24843` for more details.
* Deprecations
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
use `jax.Array` instead.
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
instead.
## jax 0.4.35 (Oct 22, 2024)
* Breaking Changes
* {func}`jax.numpy.isscalar` now returns True for any array-like object with
zero dimensions. Previously it only returned True for zero-dimensional
array-like objects with a weak dtype.
* `jax.experimental.host_callback` had been deprecated since March 2024 (JAX
  version 0.4.26) and has now been removed.
  See {jax-issue}`#20385` for a discussion of alternatives.
* Changes:
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
operations. The semi-public API `jax.lib.xla_client.FftType` has been
deprecated.
* TPU: JAX now installs TPU support from the `libtpu` package rather than
`libtpu-nightly`. For the next few releases JAX will pin an empty version of
`libtpu-nightly` as well as `libtpu` to ease the transition; that dependency
will be removed in Q1 2025.
* Deprecations:
* The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated.
No JAX APIs consume this type, so there is no replacement.
* The default behavior of {func}`jax.pure_callback` and
  {func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated, as has
  the `vectorized` parameter to those functions. The `vmap_method` parameter
  should be used instead for better-defined behavior; see the sketch after
  this list and the discussion in {jax-issue}`#23881` for more details.
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
been deprecated. Use the JAX FFI instead.
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
`jax.lib.xla_client.ops`,
`jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
`jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
`jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO
instead.
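
A sketch of the `vmap_method` replacement for the deprecated `vectorized` flag (the host callback and shapes here are illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    return np.sin(x)  # runs on the host via NumPy

out_spec = jax.ShapeDtypeStruct((4,), jnp.float32)
x = jnp.arange(4.0)
# "expand_dims" passes the whole batch to one callback invocation,
# roughly matching the old vectorized=True behavior.
result = jax.pure_callback(host_sin, out_spec, x, vmap_method="expand_dims")
```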
## jax 0.4.34 (October 4, 2024)
* New Functionality
* This release includes wheels for Python 3.13. Free-threading mode is not yet
supported.
* `jax.errors.JaxRuntimeError` has been added as a public alias for the
formerly private `XlaRuntimeError` type.
* Breaking changes
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
  * `array[0]` on a pmap result now introduces a reshape (use `array[0:1]`
    instead).
  * The per-shard shape (accessible via jax_array.addressable_shards or
    jax_array.addressable_data(0)) now has a leading (1, ...). Update code
    that directly accesses shards accordingly. The rank of the per-shard shape
    now matches that of the global shape, which is the same behavior as jit.
    This avoids costly reshapes when passing results from pmap into jit.
* `jax.experimental.host_callback` has been deprecated since March 2024, with
JAX version 0.4.26. Now we set the default value of the
`--jax_host_callback_legacy` configuration value to `True`, which means that
if your code uses `jax.experimental.host_callback` APIs, those API calls
will be implemented in terms of the new `jax.experimental.io_callback` API.
  If this breaks your code, for a very limited time, you can set
  `--jax_host_callback_legacy` to `False`. Soon we will remove that
  configuration option, so you should instead transition to using the
  new JAX callback APIs. See {jax-issue}`#20385` for a discussion.
* Deprecations
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
arguments with `ndim != 1` are now deprecated, and in the future will result
in an error.
* Internal pretty-printing tools `jax.core.pp_*` have been removed, after
being deprecated in JAX v0.4.30.
* `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
`jax.errors.JaxRuntimeError` instead.
* Deletion:
* `jax.xla_computation` is deleted. It has been 3 months since its deprecation
  in the 0.4.30 JAX release.
  Please use the AOT APIs to get the same functionality as `jax.xla_computation`;
  a sketch follows this list.
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
* You can also use `.out_info` property of `jax.stages.Lowered` to get the
output information (like tree structure, shape and dtype).
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
The argument was only used by `xmap` which was removed in 0.4.31.
* `jax.tree.map(f, None, non-None)`, which previously emitted a
  `DeprecationWarning`, now raises an error. `None`
is only a tree-prefix of itself. To preserve the current behavior, you can
ask `jax.tree.map` to treat `None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
* `jax.sharding.XLACompatibleSharding` has been removed. Please use
`jax.sharding.Sharding`.
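
A minimal sketch of the AOT-based replacement described above:

```python
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.sin(x) ** 2

# Replaces jax.xla_computation(fn)(x):
lowered = jax.jit(fn).lower(jnp.ones((3,)))
print(lowered.compiler_ir('hlo').as_hlo_text())
print(lowered.out_info)  # output tree structure, shapes, and dtypes
```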
* Bug fixes
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
if a non-boolean input was provided and `dtype=bool` was specified.
* Fixed the implementation of {func}`jax.numpy.ldexp` so that it returns the
  correct gradient.
## jax 0.4.33 (September 16, 2024)
This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
release.
A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of `libtpu`.
This release fixes an inaccurate result for F64 tanh on CPU (#23590).
## jax 0.4.32 (September 11, 2024)
Note: This release was yanked from PyPI because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.
* New Functionality
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
to support the use of the new {ref}`ffi-tutorial` to interface with custom
C++ and CUDA code from JAX.
* Changes
* `jax_enable_memories` flag is set to `True` by default.
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.
* Computations on the CPU backend may now be dispatched asynchronously in
more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Added new {func}`jax.process_indices` function to replace the
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been
  modified to no longer support complex dtypes.
* ``jax.tree_util.register_dataclass`` now checks that ``data_fields``
and ``meta_fields`` includes all dataclass fields with ``init=True``
and only them, if ``nodetype`` is a dataclass.
* Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc`
  interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`,
  {obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`,
  {obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
  {obj}`~jax.numpy.logical_or`, and {obj}`~jax.numpy.logical_xor`.
In future releases we plan to expand these to other ufuncs.
* Added {func}`jax.lax.optimization_barrier`, which allows users to prevent
compiler optimizations such as common-subexpression elimination and to
control scheduling.
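
For example, to keep the compiler from merging or reordering work across a program point:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    a = jnp.sin(x)
    # The barrier prevents optimizations such as common-subexpression
    # elimination from moving computations across this point.
    a = jax.lax.optimization_barrier(a)
    return a * a

print(f(jnp.arange(4.0)))
```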
* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
`stablehlo` dialect instead.
* Deprecations
* Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are
no longer allowed, after being deprecated since JAX v0.4.27.
* Deprecated the following APIs:
* `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly.
* `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`.
* `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`.
* The `jax.experimental.array_api` module is deprecated, and importing it is no
longer required to use the Array API. `jax.numpy` supports the array API
directly; see {ref}`python-array-api` for more information.
* The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
`jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
the future.
* `jax.numpy.round_` has been deprecated, following removal of the corresponding
API in NumPy 2.0. Use {func}`jax.numpy.round` instead.
* Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated.
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
another framework that implements the ``__dlpack__`` protocol.
## jaxlib 0.4.32 (September 11, 2024)
Note: This release was yanked from PyPI because of a data corruption bug on TPU.
See the 0.4.33 release notes for more details.
* Breaking changes
* This release of jaxlib switched to a new version of the CPU backend, which
should compile faster and leverage parallelism better. If you experience
any problems due to this change, you can temporarily enable the old CPU
backend by setting the environment variable
`XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this,
please file a JAX bug with instructions to reproduce.
* Hermetic CUDA support is added.
Hermetic CUDA uses a specific downloadable version of CUDA instead of the
  user's locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL
distributions, and then use CUDA libraries and tools as dependencies in
various Bazel targets. This enables more reproducible builds for JAX and its
supported CUDA versions.
* Changes
* SparseCore profiling is added.
* JAX now supports profiling [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore) on TPUv5p chips. These traces will be viewable in Tensorboard Profiler's [TraceViewer](https://www.tensorflow.org/guide/profiler#trace_viewer).
## jax 0.4.31 (July 29, 2024)
* Deletion
* xmap has been deleted. Please use {func}`shard_map` as the replacement.
* Changes
* The minimum cuDNN version is v9.1. This was true in previous releases also,
but we now declare this version constraint formally.
* The minimum Python version is now 3.10. 3.10 will remain the minimum
supported version until July 2025.
* The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
supported version until December 2024.
* The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum
supported version until January 2025.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
be passed *before* `index_map`. The old argument order is deprecated and
will be removed in a future release.
* Updated the repr of GPU devices to be more consistent
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
* Added the `device` property and `to_device` method to {class}`jax.Array`, as
part of JAX's [Array API](https://data-apis.org/array-api) support.
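
For example (a small sketch of the Array API-style device accessors):

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)
print(x.device)                    # the device holding this array
y = x.to_device(jax.devices()[0])  # Array API-style device transfer
```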
* Deprecations
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
`dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`.
* HLO lowering rules should no longer wrap singleton ir.Values in tuples.
Instead, return singleton ir.Values unwrapped. Support for wrapped values
will be removed in a future version of JAX.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or `enable_xla=False` is now deprecated and this support will be removed in
a future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
* The previously-deprecated function `jax.random.shuffle` has been removed;
instead use `jax.random.permutation` with `independent=True`.
## jaxlib 0.4.31 (July 29, 2024)
* Bug fixes
* Fixed a bug that meant that negative static_argnums to a jit were mishandled
by the jit dispatch fast path.
* Fixed a bug that meant triangular solves of batches of singular matrices
  produced nonsensical finite values, instead of inf or nan (#3589, #15429).
## jax 0.4.30 (June 18, 2024)
* Changes
* JAX supports ml_dtypes >= 0.2. In the 0.4.29 release, the minimum ml_dtypes
  version was bumped to 0.4.0, but this has been rolled back in this release
  to give users of both TensorFlow and JAX more time to migrate to a newer
  TensorFlow release.
* `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
* jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required.
* Added an API for exporting and serializing JAX functions. This used
to exist in `jax.experimental.export` (which is being deprecated),
and will now live in `jax.export`.
See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html).
* Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
in a future release.
* Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX
release. This previously was the case, but there was an inadvertent regression in
the last several JAX releases.
* `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead.
See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export).
* Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays
`x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`.
* `jax.xla_computation` is deprecated and will be removed in a future release.
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
* You can also use `.out_info` property of `jax.stages.Lowered` to get the
output information (like tree structure, shape and dtype).
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
## jaxlib 0.4.30 (June 18, 2024)
* Support for monolithic CUDA jaxlibs has been dropped. You must use the
plugin-based installation (`pip install jax[cuda12]` or
`pip install jax[cuda12_local]`).
## jax 0.4.29 (June 10, 2024)
* Changes
* We anticipate that this will be the last release of JAX and jaxlib
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
plugin jaxlib (e.g. `pip install jax[cuda12]`).
* JAX now requires ml_dtypes version 0.4.0 or newer.
* Removed backwards-compatibility support for old usage of the
`jax.experimental.export` API. It is not possible anymore to use
`from jax.experimental.export import export`, and instead you should use
`from jax.experimental import export`.
The removed functionality has been deprecated since 0.4.24.
* Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`.
* Deprecations
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use
`jax.sharding.Sharding`.
* `jax.experimental.Exported.in_shardings` has been renamed as
`jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`.
The old names will be removed after 3 months.
* Removed a number of previously-deprecated APIs:
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
* from {mod}`jax.lax`: `tie_in`
* from {mod}`jax.nn`: `normalize`
* from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
`translations`, `register_translation`, `xla_destructure`,
`TranslationRule`, `TranslationContext`, `XlaOp`.
* The ``tol`` argument of {func}`jax.numpy.linalg.matrix_rank` is being
deprecated and will soon be removed. Use `rtol` instead.
* The ``rcond`` argument of {func}`jax.numpy.linalg.pinv` is being
deprecated and will soon be removed. Use `rtol` instead.
* The deprecated `jax.config` submodule has been removed. To configure JAX
use `import jax` and then reference the config object via `jax.config`.
* {mod}`jax.random` APIs no longer accept batched keys, where previously
  some did unintentionally. Going forward, we recommend explicit use of
  {func}`jax.vmap` in such cases, as sketched after this list.
* In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been
renamed to `a` and `b` for consistency with other `beta` APIs.
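
For example, map explicitly over a batch of keys instead of passing the batch directly:

```python
import jax

keys = jax.random.split(jax.random.key(0), 4)  # a batch of 4 keys
# Explicit batching with vmap, instead of passing `keys` directly:
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
print(samples.shape)  # (4, 3)
```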
* New Functionality
* Added {func}`jax.experimental.Exported.in_shardings_jax` to construct
shardings that can be used with the JAX APIs from the HloShardings
that are stored in the `Exported` objects.
## jaxlib 0.4.29 (June 10, 2024)
* Bug fixes
* Fixed a bug where XLA sharded some concatenation operations incorrectly,
which manifested as an incorrect output for cumulative reductions (#21403).
* Fixed a bug where XLA:CPU miscompiled certain matmul fusions
(https://github.com/openxla/xla/pull/13301).
* Fixes a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396).
* Deprecations
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
raise an error in a future version of jax. `None` is only a tree-prefix of
itself. To preserve the current behavior, you can ask `jax.tree.map` to
treat `None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
## jax 0.4.28 (May 9, 2024)
* Bug fixes
* Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).
* Deprecations & removals
* The ``kind`` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort`
is now removed. Use `stable=True` or `stable=False` instead.
* Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu``
module. Use the ``compute_capability`` attribute of a GPU device, returned
by {func}`jax.devices` or {func}`jax.local_devices`, instead.
* The ``newshape`` argument to {func}`jax.numpy.reshape` is being deprecated
  and will soon be removed. Use `shape` instead.
* Changes
* The minimum jaxlib version of this release is 0.4.27.
## jaxlib 0.4.28 (May 9, 2024)
* Bug fixes
* Fixes a memory corruption bug in the type name of Array and JIT Python
objects in Python 3.10 or earlier.
* Fixed a warning `'+ptx84' is not a recognized feature for this target`
under CUDA 12.4.
* Fixed a slow compilation problem on CPU.
* Changes
* The Windows build is now built with Clang instead of MSVC.
## jax 0.4.27 (May 7, 2024)
* New Functionality
* Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
  following their addition in the array API 2023 standard, soon to be
  adopted by NumPy; see the sketch after this list.
* Added a new config option `jax_cpu_collectives_implementation` to select the
  implementation of cross-process collective operations used by the CPU backend.
  Available choices are `'none'` (default), `'gloo'`, and `'mpi'` (requires
  jaxlib 0.4.26). If set to `'none'`, cross-process collective operations are
  disabled.
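
A quick sketch of the two new array API-aligned functions:

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
rows = jnp.unstack(x, axis=0)         # tuple of two (3,)-shaped arrays
csum = jnp.cumulative_sum(x, axis=1)  # cumulative sum along each row
```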
* Changes
* {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
and {func}`jax.debug.callback` now use {class}`jax.Array` instead
of {class}`np.ndarray`. You can recover the old behavior by transforming
the arguments via `jax.tree.map(np.asarray, args)` before passing them
to the callback.
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
* `core.Token` is now a non-trivial class which wraps a `jax.Array`. It can
  be created and threaded in and out of computations to build up dependencies.
  The singleton object `core.token` has been removed; users should now create
  and use fresh `core.Token` objects instead.
* On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
by default. This choice can improve runtime memory usage at a compile-time
cost. Prior behavior, which produces a kernel call, can be recovered with
`jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
default causes issues, please file a bug. Otherwise, we intend to remove
this flag in a future release.
* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
lowering pass via Triton Python APIs has been removed and the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
  `a_max` are deprecated in favor of `x` (positional only), `min`, and
  `max` ({jax-issue}`20550`); see the sketch after this list.
* The `device()` method of JAX arrays has been removed, after being deprecated
since JAX v0.4.21. Use `arr.devices()` instead.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
is deprecated; empty inputs to softmax are now supported without setting this.
* In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
now leads to an error rather than a warning.
* The minimum jaxlib version is now 0.4.23.
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
  passed complex-valued inputs. This will become an error once the
  deprecation is completed.
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
related functions now raise an error, following a similar change in NumPy.
* The config option `jax_cpu_enable_gloo_collectives` is deprecated.
Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
* The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
been removed after being deprecated in JAX v0.4.22. Instead use
{attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
* The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
positional-only, following deprecation of the keywords in JAX v0.4.21.
* Non-array arguments to functions in {mod}`jax.lax.linalg` now must be
specified by keyword. Previously, this raised a DeprecationWarning.
* Array-like arguments are now required in several {mod}`jax.numpy` APIs,
including {func}`~jax.numpy.apply_along_axis`,
{func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`,
{func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`,
{func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.
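
A sketch of the new {func}`jax.numpy.clip` keywords mentioned above:

```python
import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 5)
# `min` and `max` replace the deprecated `a_min` and `a_max`.
y = jnp.clip(x, min=-1.0, max=1.0)
```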
* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
Previously, no copy would be made when the output array would have the same
dtype as the input array. This may result in some increased memory usage.
The default value is set to `copy=False` to preserve backwards compatibility.
## jaxlib 0.4.27 (May 7, 2024)
## jax 0.4.26 (April 3, 2024)
* New Functionality
* Added {func}`jax.numpy.trapezoid`, following the addition of this function in
NumPy 2.0.
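
For example, integrating sin(x) over [0, pi] with the trapezoidal rule:

```python
import jax.numpy as jnp

x = jnp.linspace(0.0, jnp.pi, 101)
area = jnp.trapezoid(jnp.sin(x), x)  # approximately 2.0
```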
* Changes
* Complex-valued {func}`jax.numpy.geomspace` now chooses the logarithmic spiral
branch consistent with that of NumPy 2.0.
* The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'`
and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has
changed](https://github.com/jax-ml/jax/issues/19085) so that
mapping over keys results in random generation only from the first
key in the batch.
* Docs now use `jax.random.key` for construction of PRNG key arrays
rather than `jax.random.PRNGKey`.
* Deprecations & Removals
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
* {func}`jax.clear_backends` is deprecated as it does not necessarily do what
its name suggests and can lead to unexpected consequences, e.g., it will not
destroy existing backends and release corresponding owned resources. Use
{func}`jax.clear_caches` if you only want to clean up compilation caches.
  For backward compatibility, or if you really need to switch or reinitialize
  the default backend, use {func}`jax.extend.backend.clear_backends`.
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
* The `jax.experimental.host_callback` module is deprecated.
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the
new callbacks. See {jax-issue}`#20385` for a discussion.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array now results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
This flag was long deprecated and did nothing; its use was a no-op.
* The previously-deprecated imports `jax.interpreters.ad.config` and
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
and `jax.extend.source_info_util` instead.
* JAX export does not support older serialization versions anymore. Version 9
has been supported since October 27th, 2023 and has become the default
since February 1, 2024.
See [a description of the versions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
This change could break clients that set a specific
JAX serialization version lower than 9.
## jaxlib 0.4.26 (April 3, 2024)
* Changes
* JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been
dropped.
* JAX now supports NumPy 2.0.
## jax 0.4.25 (Feb 26, 2024)
* New Features
* Added [CUDA Array
Interface](https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html)
import support (requires jaxlib 0.4.24).
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
* Added {mod}`jax.tree` module, with a more convenient interface for referencing functions
in {mod}`jax.tree_util`.
* {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
`inner_treedef=None`, in which case the inner treedef will be automatically inferred.
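
A small sketch of the new {mod}`jax.tree` aliases and the inferred inner treedef:

```python
import jax

tree = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
outer = jax.tree.structure([0, 0])  # treedef of the outer list
# inner_treedef=None: the inner dict structure is inferred automatically.
transposed = jax.tree.transpose(outer, None, tree)
print(transposed)  # {'a': [1, 3], 'b': [2, 4]}
```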
* Changes
* Pallas now uses XLA instead of the Triton Python APIs to compile Triton
kernels. You can revert to the old behavior by setting the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable to `"0"`.
* Several deprecated APIs in {mod}`jax.interpreters.xla` that were removed in v0.4.24
have been re-added in v0.4.25, including `backend_specific_translations`,
`translations`, `register_translation`, `xla_destructure`, `TranslationRule`,
`TranslationContext`, and `XLAOp`. These are still considered deprecated, and
will be removed again in the future when better replacements are available.
Refer to {jax-issue}`#19816` for discussion.
* Deprecations & Removals
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
solves with `b.ndim > 1`. In the future these will be treated as batched 2D
solves.
* Conversion of a non-scalar array to a Python scalar now raises an error, regardless
of the size of the array. Previously a deprecation warning was raised in the case of
non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
* The previously deprecated configuration APIs have been removed
following a standard 3 months deprecation cycle (see {ref}`api-compatibility`).
These include
* the `jax.config.config` object and
* the `define_*_state` and `DEFINE_*` methods of {data}`jax.config`.
* Importing the `jax.config` submodule via `import jax.config` is deprecated.
To configure JAX use `import jax` and then reference the config object
via `jax.config`.
* The minimum jaxlib version is now 0.4.20.
## jaxlib 0.4.25 (Feb 26, 2024)
## jax 0.4.24 (Feb 6, 2024)
* Changes
* JAX lowering to StableHLO no longer depends on physical devices.
  If your primitive wraps custom_partitioning or JAX callbacks in the lowering
  rule (i.e., the function passed to the `rule` parameter of `mlir.register_lowering`),
  then add your primitive to the `jax._src.dispatch.prim_requires_devices_during_lowering` set.
This is needed because custom_partitioning and JAX callbacks need physical
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
and `descending` arguments.
* Several changes to the handling of shape polymorphism (used in
{mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`):
* cleaner pretty-printing of symbolic expressions ({jax-issue}`#19227`)
* added the ability to specify symbolic constraints on the dimension variables.
This makes shape polymorphism more expressive, and gives a way to workaround
limitations in the reasoning about inequalities.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* with the addition of symbolic constraints ({jax-issue}`#19235`) we now
consider dimension variables from different scopes to be different, even
if they have the same name. Symbolic expressions from different scopes
cannot interact, e.g., in arithmetic operations.
Scopes are introduced by {func}`jax.experimental.jax2tf.convert`,
{func}`jax.experimental.export.symbolic_shape`, {func}`jax.experimental.export.symbolic_args_specs`.
The scope of a symbolic expression `e` can be read with `e.scope` and passed
into the above functions to direct them to construct symbolic expressions in
a given scope.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* simplified and faster equality comparisons, where we consider two symbolic dimensions
to be equal if the normalized form of their difference reduces to 0
({jax-issue}`#19231`; note that this may result in user-visible behavior
changes)
* improved the error messages for inconclusive inequality comparisons
({jax-issue}`#19235`).
* the `core.non_negative_dim` API (introduced recently)
was deprecated and `core.max_dim` and `core.min_dim` were introduced
({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions.
You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
* the `shape_poly.is_poly_dim` is deprecated in favor of `export.is_symbolic_dim`
({jax-issue}`#19282`).
* the `export.args_specs` is deprecated in favor of `export.symbolic_args_specs`
  ({jax-issue}`#19283`).
* the `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated; use
  strings for polymorphic shape specifications ({jax-issue}`#19284`).
* JAX default native serialization version is now 9. This is relevant
for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`.
See [description of version numbers](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will
continue to work for a deprecation period of 3 months.
* Added {func}`jax.scipy.stats.sem`.
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
reshaped to the dimension of the input, following a similar change to
{func}`numpy.unique` in NumPy 2.0.
* {func}`jax.numpy.sign` now returns `x / abs(x)` for nonzero complex inputs. This is
consistent with the behavior of {func}`numpy.sign` in NumPy version 2.0.
* {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0
convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
of {func}`scipy.special.logsumexp` in SciPy v1.13.
* JAX now supports the bool DLPack type for both import and export.
Previously bool values could not be imported and were exported as integers.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
This includes:
* From {mod}`jax.core`: `TracerArrayConversionError`,
`TracerIntegerConversionError`, `UnexpectedTracerError`,
`as_hashable_function`, `collections`, `dtypes`, `lu`, `map`,
`namedtuple`, `partial`, `pp`, `ref`, `safe_zip`, `safe_map`,
`source_info_util`, `total_ordering`, `traceback_util`, `tuple_delete`,
`tuple_insert`, and `zip`.
* From {mod}`jax.lax`: `dtypes`, `itertools`, `naryop`, `naryop_dtype_rule`,
`standard_abstract_eval`, `standard_naryop`, `standard_primitive`,
`standard_unop`, `unop`, and `unop_dtype_rule`.
* The `jax.linear_util` submodule and all its contents.
* The `jax.prng` submodule and all its contents.
* From {mod}`jax.random`: `PRNGKeyArray`, `KeyArray`, `default_prng_impl`,
`threefry_2x32`, `threefry2x32_key`, `threefry2x32_p`, `rbg_key`, and
`unsafe_rbg_key`.
* From {mod}`jax.tree_util`: `register_keypaths`, `AttributeKeyPathEntry`, and
`GetItemKeyPathEntry`.
* from {mod}`jax.interpreters.xla`: `backend_specific_translations`, `translations`,
`register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`,
`axis_groups`, `ShapedArray`, `ConcreteArray`, `AxisEnv`, `backend_compile`,
and `XLAOp`.
* from {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`,
`trapz`, and `in1d`.
* from {mod}`jax.scipy.linalg`: `tril` and `triu`.
* The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been
removed. Use {func}`jax.random.key_data` instead.
* `bool(empty_array)` now raises an error rather than returning `False`. This
previously raised a deprecation warning, and follows a similar change in NumPy.
* Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
removed in the future. Use the "stablehlo" dialect instead.
* {mod}`jax.random`: passing batched keys directly to random number generation functions,
such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
* {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0.
## jaxlib 0.4.24 (Feb 6, 2024)
* Changes
* JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has been
dropped.
* `cost_analysis` now works with cross-compiled `Compiled` objects (i.e. when
using `.lower().compile()` with a topology object, e.g., to compile for
Cloud TPU from a non-TPU computer).
* Added [CUDA Array
Interface](https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html)
import support (requires jax 0.4.25).
## jax 0.4.23 (Dec 13, 2023)
## jaxlib 0.4.23 (Dec 13, 2023)
* Fixed a bug that caused verbose logging from the GPU compiler during
compilation.
## jax 0.4.22 (Dec 13, 2023)
* Deprecations
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
Explicit buffers have been replaced by the more flexible array sharding interface,
but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
## jaxlib 0.4.22 (Dec 13, 2023)
## jax 0.4.21 (Dec 4, 2023)
* New Features
* Added {obj}`jax.nn.squareplus`.
* Changes
* The minimum jaxlib version is now 0.4.19.
* Released wheels are built now with clang instead of gcc.
* Enforce that the device backend has not been initialized prior to calling `jax.distributed.initialize()`.
* Automate arguments to `jax.distributed.initialize()` in cloud TPU environments.
* Deprecations
* The previously-deprecated `sym_pos` argument has been removed from
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
  * Passing `None` to {func}`jax.numpy.array` or {func}`jax.numpy.asarray`, either directly or
    within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
    It is currently converted to NaN, and in the future will raise a {obj}`TypeError`.
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
keyword arguments has been deprecated, to match `numpy.where`.
  * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
    that cannot be converted to a JAX array is deprecated and now raises a
    {obj}`DeprecationWarning`. Currently the functions return False; in the future this
    will raise an exception.
* The `device()` method of JAX arrays is deprecated. Depending on the context, it may
be replaced with one of the following:
- {meth}`jax.Array.devices` returns the set of all devices used by the array.
- {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
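
    A short sketch of the replacements (the array here is illustrative):

    ```python
    import jax.numpy as jnp

    x = jnp.ones((4, 4))
    devs = x.devices()   # set of all devices used by the array
    shard = x.sharding   # the array's sharding configuration
    ```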
## jaxlib 0.4.21 (Dec 4, 2023)
* Changes
* In preparation for adding distributed CPU support, JAX now treats CPU
devices identically to GPU and TPU devices, that is:
* `jax.devices()` includes all devices present in a distributed job, even
those not local to the current process. `jax.local_devices()` still only
includes devices local to the current process, so if the change to
`jax.devices()` breaks you, you most likely want to use
`jax.local_devices()` instead.
* CPU devices now receive a globally unique ID number within a distributed
job; previously CPU devices would receive a process-local ID number.
* The `process_index` of each CPU device will now match any GPU or TPU
devices within the same process; previously the `process_index` of a CPU
device was always 0.
* On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to
1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
* Bug fixes
* Fixed error/hang when an array with non-finite values is passed to a
non-symmetric eigendecomposition (#18226). Arrays with non-finite values now
produce arrays full of NaNs as outputs.
## jax 0.4.20 (Nov 2, 2023)
## jaxlib 0.4.20 (Nov 2, 2023)
* Bug fixes
* Fixed some type confusion between E4M3 and E5M2 float8 types.
## jax 0.4.19 (Oct 19, 2023)
* New Features
* Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that
are convertible to JAX dtypes.
* Added `jax.numpy.fill_diagonal`.
* Changes
* JAX now requires SciPy 1.9 or newer.
* Bug fixes
* Only process 0 in a multicontroller distributed JAX program will write
persistent compilation cache entries. This fixes write contention if the
cache is placed on a network file system such as GCS.
* The version check for cusolver and cufft no longer considers the patch
versions when determining if the installed version of these libraries is at
least as new as the versions against which JAX was built.
## jaxlib 0.4.19 (Oct 19, 2023)
* Changes
* jaxlib will now always prefer pip-installed NVIDIA CUDA libraries
(nvidia-... packages) over any other CUDA installation if they are
installed, including installations named in `LD_LIBRARY_PATH`. If this
causes problems and the intent is to use a system-installed CUDA, the fix is
to remove the pip installed CUDA library packages.
## jax 0.4.18 (Oct 6, 2023)
## jaxlib 0.4.18 (Oct 6, 2023)
* Changes
* CUDA jaxlibs now depend on the user to install a compatible NCCL version.
If using the recommended `cuda12_pip` installation, NCCL should be installed
automatically. Currently, NCCL 2.16 or newer is required.
* We now provide Linux aarch64 wheels, both with and without NVIDIA GPU
support.
* {meth}`jax.Array.item` now supports optional index arguments.
* Deprecations
* A number of internal utilities and inadvertent exports in {mod}`jax.lax` have
been deprecated, and will be removed in a future release.
* `jax.lax.dtypes`: use `jax.dtypes` instead.
* `jax.lax.itertools`: use `itertools` instead.
* `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`,
`standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are
internal utilities, now deprecated without replacement.
* Bug fixes
* Fixed Cloud TPU regression where compilation would OOM due to smem.
## jax 0.4.17 (Oct 3, 2023)
* New features
* Added new {func}`jax.numpy.bitwise_count` function, matching the API of the similar
function recently added to NumPy.
* Deprecations
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
* Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument
to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
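
    For example, a minimal sketch of the migration for the `rbg` implementation:

    ```python
    import jax

    # Previously: key = jax.random.rbg_key(0)
    key = jax.random.PRNGKey(0, impl='rbg')
    ```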
* Changes:
* CUDA: JAX now verifies that the CUDA libraries it finds are at least as new
as the CUDA libraries that JAX was built against. If older libraries are
found, JAX raises an exception since that is preferable to mysterious
failures and crashes.
* Removed the "No GPU/TPU" found warning. Instead warn if, on Linux, an
NVIDIA GPU or a Google TPU are found but not used and `--jax_platforms` was
not specified.
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
1.11.
* Most `jax.numpy` functions and attributes now have fully-defined type stubs.
Previously many of these were treated as `Any` by static type checkers like
`mypy` and `pytype`.
## jaxlib 0.4.17 (Oct 3, 2023)
* Changes:
* Python 3.12 wheels were added in this release.
* The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.
* Bug fixes:
* Fixed log spam from ABSL when the JAX CPU backend was initialized.
## jax 0.4.16 (Sept 18, 2023)
* Changes
* Added {class}`jax.numpy.ufunc`, as well as {func}`jax.numpy.frompyfunc`, which can convert
any scalar-valued function into a {func}`numpy.ufunc`-like object, with methods such as
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
    {meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`); see the sketch at the
    end of this list.
* Added {func}`jax.scipy.integrate.trapezoid`.
  * When not running under IPython: when an exception is raised, JAX now filters out the
    entirety of its internal frames from tracebacks, without the "unfiltered stack trace"
    that previously appeared. This should produce much friendlier-looking tracebacks. See
    [here](https://github.com/jax-ml/jax/pull/16949) for an example.
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
applies to mock devices or user created devices. `jax.devices()` are
already hashable.
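
  A minimal sketch of the `jax.numpy.frompyfunc` API mentioned above (the scalar
  function and the `identity` keyword shown here are illustrative):

  ```python
  import jax.numpy as jnp

  # Wrap a scalar function as a ufunc-like object; `identity` enables reductions.
  add = jnp.frompyfunc(lambda a, b: a + b, nin=2, nout=1, identity=0)

  print(add.outer(jnp.arange(3), jnp.arange(4)).shape)  # (3, 4)
  print(add.reduce(jnp.arange(10)))                     # 45
  ```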
* Breaking changes:
* jax2tf now uses native serialization by default. See
the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
* `jax.jaxpr_util` has been removed from the public JAX namespace.
* `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
  * The backwards-compatibility flag `--jax_host_callback_ad_transforms`,
    introduced in December 2021, has been removed.
* Deprecations:
* Several `jax.numpy` APIs have been deprecated following
[NumPy NEP-52](https://numpy.org/neps/nep-0052-python-api-cleanup.html):
* `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
* `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
* `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
* `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
Use the built-in `math.prod` instead.
* A number of exports from `jax.interpreters.xla` related to defining
HLO lowering rules for custom JAX primitives have been deprecated. Custom
primitives should be defined using the StableHLO lowering utilities in
`jax.interpreters.mlir` instead.
* The following previously-deprecated functions have been removed after a
three-month deprecation period:
* `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
* `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
* `jax.numpy.alltrue`: use `jax.numpy.all`.
* `jax.numpy.sometrue`: use `jax.numpy.any`.
* `jax.numpy.product`: use `jax.numpy.prod`.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
* Deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
for type annotations, and `jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)` for
runtime detection of typed prng keys.
* The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use
{func}`jax.random.key_data` instead.
* `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use
`jax.lax.with_sharding_constraint` instead.
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`
have been removed. Opaque dtypes have been renamed to Extended dtypes; use
`jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14).
* The utility `jax.interpreters.xla.register_collective_primitive` has been
removed. This utility did nothing useful in recent JAX releases and calls
to it can be safely removed.
## jaxlib 0.4.16 (Sept 18, 2023)
* Changes:
* Sparse CSR matrix multiplications via the experimental jax sparse APIs
no longer uses a deterministic algorithm on NVIDIA GPUs. This change was
made to improve compatibility with CUDA 12.2.1.
* Bug fixes:
* Fixed a crash on Windows due to a fatal LLVM error related to out-of-order
sections and IMAGE_REL_AMD64_ADDR32NB relocations
(https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
## jax 0.4.14 (July 27, 2023)
* Changes
  * `jax.jit` takes `donate_argnames` as an argument. Its semantics are similar
    to `static_argnames`.
    If neither `donate_argnums` nor `donate_argnames` is provided, no
    arguments are donated. If `donate_argnums` is not provided but
    `donate_argnames` is, or vice versa, JAX uses `inspect.signature(fun)`
    to find any positional arguments that correspond to `donate_argnames`
    (or vice versa). If both `donate_argnums` and `donate_argnames` are
    provided, `inspect.signature` is not used, and only the actual parameters
    listed in either `donate_argnums` or `donate_argnames` will be donated.
    A sketch of the new argument appears at the end of this list.
* {func}`jax.random.gamma` has been re-factored to a more efficient algorithm
with more robust endpoint behavior ({jax-issue}`#16779`). This means that the
sequence of values returned for a given `key` will change between JAX v0.4.13
and v0.4.14 for `gamma` and related samplers (including {func}`jax.random.ball`,
{func}`jax.random.beta`, {func}`jax.random.chisquare`, {func}`jax.random.dirichlet`,
{func}`jax.random.generalized_normal`, {func}`jax.random.loggamma`, {func}`jax.random.t`).
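
  A minimal sketch of donating an argument by name (the function and parameter
  names here are illustrative):

  ```python
  import jax
  import jax.numpy as jnp

  def apply_update(params, grads):
      return params - 0.1 * grads

  # Donating `params` allows XLA to reuse its buffer for the output.
  apply_update = jax.jit(apply_update, donate_argnames="params")

  params = jnp.ones(1024)
  params = apply_update(params, jnp.full(1024, 0.5))
  ```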
* Deletions
* `in_axis_resources` and `out_axis_resources` have been deleted from pjit since
it has been more than 3 months since their deprecation. Please use
`in_shardings` and `out_shardings` as the replacement.
This is a safe and trivial name replacement. It does not change any of the
current pjit semantics and doesn't break any code.
You can still pass in `PartitionSpecs` to in_shardings and out_shardings.
* Deprecations
* Python 3.8 support has been dropped as per
https://jax.readthedocs.io/en/latest/deprecation.html
* JAX now requires NumPy 1.22 or newer as per
https://jax.readthedocs.io/en/latest/deprecation.html
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
no longer supported, after being deprecated in JAX version 0.4.7.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* The following `jax.Array` methods have been removed, after being deprecated
in JAX v0.4.5:
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
* The following APIs have been removed after previous deprecation:
* `jax.ad`: use {mod}`jax.interpreters.ad`.
* `jax.curry`: use ``curry = lambda f: partial(partial, f)``.
* `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`.
* `jax.pxla`: use {mod}`jax.interpreters.pxla`.
* `jax.xla`: use {mod}`jax.interpreters.xla`.
* `jax.ShapedArray`: use {class}`jax.core.ShapedArray`.
* `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`.
* `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`.
* `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`.
* `jax.numpy.DeviceArray`: use {class}`jax.Array`.
* `jax.stages.Compiled.compiler_ir`: use {func}`jax.stages.Compiled.as_text`.
* Breaking changes
* JAX now requires ml_dtypes version 0.2.0 or newer.
  * To fix a corner case, calls to {func}`jax.lax.cond` with five
    arguments will always resolve to the "common operands" `cond`
    behavior (as documented) if the second and third arguments are
    callable, even if other operands are callable as well. See
    [#16413](https://github.com/jax-ml/jax/issues/16413), and the sketch after this list.
* The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
which did nothing, have been removed. These options have been true by
default for many releases.
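
  A sketch of the five-argument "common operands" form (the values are illustrative):

  ```python
  import jax
  import jax.numpy as jnp

  x, y = jnp.float32(1.0), jnp.float32(2.0)
  # cond(pred, true_fun, false_fun, *operands)
  out = jax.lax.cond(x > 0, lambda a, b: a + b, lambda a, b: a - b, x, y)
  ```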
* New features
  * JAX now supports a configuration flag `--jax_serialization_version`
    and a `JAX_SERIALIZATION_VERSION` environment variable to control the
    serialization version ({jax-issue}`#16746`).
* jax2tf in presence of shape polymorphism now generates code that checks
certain shape constraints, if the serialization version is at least 7.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
## jaxlib 0.4.14 (July 27, 2023)
* Deprecations
* Python 3.8 support has been dropped as per
https://jax.readthedocs.io/en/latest/deprecation.html
## jax 0.4.13 (June 22, 2023)
* Changes
* `jax.jit` now allows `None` to be passed to `in_shardings` and
`out_shardings`. The semantics are as follows:
    * For in_shardings, JAX will mark it as replicated but this behavior
can change in the future.
* For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
* `jax.experimental.pjit.pjit` also allows `None` to be passed to
`in_shardings` and `out_shardings`. The semantics are as follows:
* If the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
    * For in_shardings, JAX will mark it as replicated but this behavior
can change in the future.
* For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
* If the mesh context manager is provided, None will imply that the value
will be replicated on all devices of the mesh.
  * `Executable.cost_analysis()` now works on Cloud TPU.
* Added a warning if a non-allowlisted `jaxlib` plugin is in use.
* Added `jax.tree_util.tree_leaves_with_path`.
  * `None` is no longer a valid input to
    `jax.experimental.multihost_utils.host_local_array_to_global_array` or
    `jax.experimental.multihost_utils.global_array_to_host_local_array`.
    Please use `jax.sharding.PartitionSpec()` if you want to replicate your
    input, as in the sketch below.
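
  A runnable sketch that replicates a host-local input (the single-axis mesh
  here is illustrative):

  ```python
  import jax
  import numpy as np
  from jax.sharding import Mesh, PartitionSpec
  from jax.experimental import multihost_utils

  mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
  local = np.arange(8.0)
  # Pass PartitionSpec() rather than None to replicate the input:
  global_arr = multihost_utils.host_local_array_to_global_array(
      local, mesh, PartitionSpec())
  ```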
* Bug fixes
* Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
is named `cudnn89` instead of `cudnn88`.
* Deprecations
* The `native_serialization_strict_checks` parameter to
{func}`jax.experimental.jax2tf.convert` is deprecated in favor of the
    new `native_serialization_disabled_checks` ({jax-issue}`#16347`).
## jaxlib 0.4.13 (June 22, 2023)
* Changes
* Added Windows CPU-only wheels to the `jaxlib` Pypi release.
* Bug fixes
  * `__cuda_array_interface__` was broken in previous jaxlib versions and is now
    fixed ({jax-issue}`#16440`).
* Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.
## jax 0.4.12 (June 8, 2023)
* Changes
  * Added {class}`jax.scipy.spatial.transform.Rotation` and {class}`jax.scipy.spatial.transform.Slerp`.
* Deprecations
  * `jax.abstract_arrays` and its contents are now deprecated. See related
    functionality in {mod}`jax.core`.
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
* `jax.sharding.OpShardingSharding` has been removed since it has been 3
months since it was deprecated.
## jaxlib 0.4.12 (June 8, 2023)
* Changes
* Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous
versions of jaxlib should work on Hopper but would have a long
JIT-compilation delay the first time a JAX operation was executed.
* Bug fixes
* Fixes incorrect source line information in JAX-generated Python tracebacks
under Python 3.11.
* Fixes crash when printing local variables of frames in JAX-generated Python
tracebacks (#16027).
## jax 0.4.11 (May 31, 2023)
* Deprecations
* The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
* `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
as input and remove the optional `in_shardings` argument to `pjit`.
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
* `jax.interpreters.xla.Buffer`: use `jax.Array`.
* `jax.interpreters.xla.Device`: use `jax.Device`.
* `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
* `jax.interpreters.xla.device_put`: use `jax.device_put`.
* `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
* `axis_resources` argument of `with_sharding_constraint` is removed. Please
use `shardings` instead.
## jaxlib 0.4.11 (May 31, 2023)
* Changes
* Added `memory_stats()` method to `Device`s. If supported, this returns a
dict of string stat names with int values, e.g. `"bytes_in_use"`, or None if
the platform doesn't support memory statistics. The exact stats returned may
vary across platforms. Currently only implemented on Cloud TPU.
* Readded support for the Python buffer protocol (`memoryview`) on CPU
devices.
## jax 0.4.10 (May 11, 2023)
## jaxlib 0.4.10 (May 11, 2023)
* Changes
* Fixed `'apple-m1' is not a recognized processor for this target (ignoring
processor)` issue that prevented previous release from running on Mac M1.
## jax 0.4.9 (May 9, 2023)
* Changes
  * The flags `experimental_cpp_jit`, `experimental_cpp_pjit` and
    `experimental_cpp_pmap` have been removed.
    They are now always on.
* Accuracy of singular value decomposition (SVD) on TPU has been improved
(requires jaxlib 0.4.9).
* Deprecations
* `jax.experimental.gda_serialization` is deprecated and has been renamed to
`jax.experimental.array_serialization`.
Please change your imports to use `jax.experimental.array_serialization`.
* The `in_axis_resources` and `out_axis_resources` arguments of pjit have been
deprecated. Please use `in_shardings` and `out_shardings` respectively.
* The function `jax.numpy.msort` has been removed. It has been deprecated since
JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
* `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation`
since they were only used with sharded_jit and sharded_jit is long gone.
* `instantiate_const_outputs` argument has been removed from `jax.xla_computation`
since it has been unused for a very long time.
## jaxlib 0.4.9 (May 9, 2023)
## jax 0.4.8 (March 29, 2023)
* Breaking changes
* A major component of the Cloud TPU runtime has been upgraded. This enables
the following new features on Cloud TPU:
* {func}`jax.debug.print`, {func}`jax.debug.callback`, and
{func}`jax.debug.breakpoint()` now work on Cloud TPU
* Automatic TPU memory defragmentation
{func}`jax.experimental.host_callback` is no longer supported on Cloud TPU
with the new runtime component. Please file an issue on the [JAX issue
tracker](https://github.com/jax-ml/jax/issues) if the new `jax.debug` APIs
are insufficient for your use case.
The old runtime component will be available for at least the next three
months by setting the environment variable
`JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new
runtime for any reason, please let us know on the [JAX issue
tracker](https://github.com/jax-ml/jax/issues).
* Changes
* The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
* Deprecations
* CUDA 11.4 support has been dropped. JAX GPU wheels only support
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
from source.
* `global_arg_shapes` argument of pmap only worked with sharded_jit and has
been removed from pmap. Please migrate to pjit and remove global_arg_shapes
from pmap.
## jax 0.4.7 (March 27, 2023)
* Changes
* As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration
`jax.config.jax_array` cannot be disabled anymore.
* `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
* {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization`
parameter to use JAX's native lowering to StableHLO to obtain a
StableHLO module for the entire JAX function instead of lowering each JAX
primitive to a TensorFlow op. This simplifies the internals and increases
the confidence that what you serialize matches the JAX native semantics.
See [documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
As part of this change the config flag `--jax2tf_default_experimental_native_lowering`
has been renamed to `--jax2tf_native_serialization`.
* JAX now depends on `ml_dtypes`, which contains definitions of NumPy types
like bfloat16. These definitions were previously internal to JAX, but have
been split into a separate package to facilitate sharing them with other
projects.
* JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
* Deprecations
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
for which it is an alias.
* The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use
`jax.Array` instead.
* Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
* `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
* `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded
jax.Arrays as input and remove the `in_shardings` argument to pjit since
it is optional.
## jaxlib 0.4.7 (March 27, 2023)
Changes:
* jaxlib now depends on `ml_dtypes`, which contains definitions of NumPy types
like bfloat16. These definitions were previously internal to JAX, but have
been split into a separate package to facilitate sharing them with other
projects.
## jax 0.4.6 (Mar 9, 2023)
* Changes
  * `jax.tree_util` now contains a set of APIs that allow users to define keys for their
    custom pytree nodes (see the sketch at the end of this section). This includes:
    * `tree_flatten_with_path`, which flattens a tree and returns not only each leaf but
      also its key path.
    * `tree_map_with_path`, which can map a function that takes the key path as an argument.
    * `register_pytree_with_keys`, to register how the key path and leaves should look
      in a custom pytree node.
    * `keystr`, which pretty-prints a key path.
* {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`)
that can be used to declare the output shape and type of the result. This enables
{func}`jax2tf.call_tf` to work in the presence of shape polymorphism. ({jax-issue}`#14734`).
* Deprecations
* The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months
from Mar 10 2023:
* `register_keypaths`: use {func}`jax.tree_util.register_pytree_with_keys` instead.
    * `AttributeKeyPathEntry`: use `GetAttrKey` instead.
    * `GetitemKeyPathEntry`: use `SequenceKey` or `DictKey` instead.
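
A minimal sketch of the new key-path APIs (the pytree here is illustrative):

```python
import jax

tree = {"w": [1.0, 2.0], "b": 3.0}

# Flatten, keeping each leaf's key path:
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in leaves_with_paths:
    print(jax.tree_util.keystr(path), leaf)   # e.g. ['b'] 3.0

# Map a function that also receives each leaf's key path:
labeled = jax.tree_util.tree_map_with_path(
    lambda path, x: (jax.tree_util.keystr(path), x), tree)
```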
## jaxlib 0.4.6 (Mar 9, 2023)
## jax 0.4.5 (Mar 2, 2023)
* Deprecations
* `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
`jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
* The following `jax.Array` methods are deprecated and will be removed 3 months from
Feb 23 2023:
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
## jax 0.4.4 (Feb 16, 2023)
* Changes
* The implementation of `jit` and `pjit` has been merged. Merging jit and pjit
changes the internals of JAX without affecting the public API of JAX.
Before, `jit` was a final style primitive. Final style means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see
[this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.
    You can disable it only via the environment variable, i.e.
    `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
The merge must be disabled via an environment variable since it affects JAX
at import time so it needs to be disabled before jax is imported.
* `axis_resources` argument of `with_sharding_constraint` is deprecated.
Please use `shardings` instead. There is no change needed if you were using
`axis_resources` as an arg. If you were using it as a kwarg, then please
use `shardings` instead. `axis_resources` will be removed after 3 months
from Feb 13, 2023.
  * Added the {mod}`jax.typing` module, with tools for type annotations of JAX
    functions.
* The following names have been deprecated:
* `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
* `jax.experimental.maps.Mesh`. Use `jax.sharding.Mesh`
instead.
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* Breaking Changes
  * The `initial` argument to reduction functions like {func}`jax.numpy.sum`
    is now required to be a scalar, consistent with the corresponding NumPy API.
The previous behavior of broadcasting the output against non-scalar `initial`
values was an unintentional implementation detail ({jax-issue}`#14446`).
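
    For example (a sketch; the array is illustrative):

    ```python
    import jax.numpy as jnp

    x = jnp.arange(6).reshape(2, 3)
    jnp.sum(x, axis=0, initial=10)             # scalar `initial`: fine
    # jnp.sum(x, axis=0, initial=jnp.ones(3))  # now an error: must be a scalar
    ```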
## jaxlib 0.4.4 (Feb 16, 2023)
* Breaking changes
* Support for NVIDIA Kepler series GPUs has been removed from the default
`jaxlib` builds. If Kepler support is needed, it is still possible to
build `jaxlib` from source with Kepler support (via the
`--cuda_compute_capabilities=sm_35` option to `build.py`), however note
that CUDA 12 has completely dropped support for Kepler GPUs.
## jax 0.4.3 (Feb 8, 2023)
* Breaking changes
* Deleted {func}`jax.scipy.linalg.polar_unitary`, which was a deprecated JAX
extension to the scipy API. Use {func}`jax.scipy.linalg.polar` instead.
* Changes
* Added {func}`jax.scipy.stats.rankdata`.
## jaxlib 0.4.3 (Feb 8, 2023)
* `jax.Array` now has the non-blocking `is_ready()` method, which returns `True`
if the array is ready (see also {func}`jax.block_until_ready`).
## jax 0.4.2 (Jan 24, 2023)
* Breaking changes
* Deleted `jax.experimental.callback`
* Operations with dimensions in presence of jax2tf shape polymorphism have
been generalized to work in more scenarios, by converting the symbolic
dimension to JAX arrays. Operations involving symbolic dimensions and
`np.ndarray` now can raise errors when the result is used as a shape value
({jax-issue}`#14106`).
  * jaxpr objects now raise an error on attribute setting in order to avoid
    problematic mutations ({jax-issue}`#14102`).
* Changes
* {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`)
that can be used to declare whether an instance can be removed or replicated
by JAX optimizations such as dead-code elimination ({jax-issue}`#13980`).
* Added more support for floordiv and mod for jax2tf shape polymorphism. Previously,
certain division operations resulted in errors in presence of symbolic dimensions
({jax-issue}`#14108`).
## jaxlib 0.4.2 (Jan 24, 2023)
* Changes
  * Set `JAX_USE_PJRT_C_API_ON_TPU=1` to enable the new Cloud TPU runtime, featuring
    automatic device memory defragmentation.
## jax 0.4.1 (Dec 13, 2022)
* Changes
* Support for Python 3.7 has been dropped, in accordance with JAX's
{ref}`version-support-policy`.
* We introduce `jax.Array` which is a unified array type that subsumes
`DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX.
The `jax.Array` type helps make parallelism a core feature of JAX,
simplifies and unifies JAX internals, and allows us to unify `jit` and
`pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some
breaking change to the `pjit` API. The [jax.Array migration
guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can
help you migrate your codebase to `jax.Array`. You can also look at the
[Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
tutorial to understand the new concepts.
* `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
`jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
deprecated and will be removed in 3 months.
* `with_sharding_constraint`s new public endpoint is
`jax.lax.with_sharding_constraint`.
* If using ABSL flags together with `jax.config`, the ABSL flag values are no
longer read or written after the JAX configuration options are initially
populated from the ABSL flags. This change improves performance of reading
`jax.config` options, which are used pervasively in JAX.
  * The `jax2tf.call_tf` function now uses, for TF lowering, the first TF
    device of the same platform as the embedding JAX computation.
    Previously, it used the 0th device for the JAX-default backend.
* A number of `jax.numpy` functions now have their arguments marked as
positional-only, matching NumPy.
* `jnp.msort` is now deprecated, following the deprecation of `np.msort` in numpy 1.24.
It will be removed in a future release, in accordance with the {ref}`api-compatibility`
policy. It can be replaced with `jnp.sort(a, axis=0)`.
## jaxlib 0.4.1 (Dec 13, 2022)
* Changes
* Support for Python 3.7 has been dropped, in accordance with JAX's
{ref}`version-support-policy`.
* The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of
the total GPU memory instead of the previous behavior of using currently available GPU memory
    to calculate preallocation. Please refer to
    [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for
    more details; see also the sketch after this list.
* The deprecated method `.block_host_until_ready()` has been removed. Use
`.block_until_ready()` instead.
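
  A minimal sketch of the new behavior, requesting 50% of total GPU memory:

  ```python
  import os

  # Must be set before JAX initializes its GPU backend.
  os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

  import jax
  ```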
## jax 0.4.0 (Dec 12, 2022)
* The release was yanked.
## jaxlib 0.4.0 (Dec 12, 2022)
* The release was yanked.
## jax 0.3.25 (Nov 15, 2022)
* Changes
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
* {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires
jaxlib > 0.3.24.
* New functions {func}`jax.lax.linalg.hessenberg`,
{func}`jax.lax.linalg.tridiagonal`, and
{func}`jax.lax.linalg.householder_product` were added. Householder reduction
is currently CPU-only and tridiagonal reductions are supported on CPU and
GPU only.
* The gradients of `svd` and `jax.numpy.linalg.pinv` are now computed more
economically for non-square matrices.
* Breaking Changes
* Deleted the `jax_experimental_name_stack` config option.
  * A string `axis_names` argument to the
    {class}`jax.experimental.maps.Mesh` constructor is now converted into a singleton tuple
    instead of being unpacked into a sequence of character axis names.
## jaxlib 0.3.25 (Nov 15, 2022)
* Changes
* Added support for tridiagonal reductions on CPU and GPU.
* Added support for upper Hessenberg reductions on CPU.
* Bugs
* Fixed a bug that meant that frames in tracebacks captured by JAX were
incorrectly mapped to source lines under Python 3.10+
## jax 0.3.24 (Nov 4, 2022)
* Changes
* JAX should be faster to import. We now import scipy lazily, which accounted
for a significant fraction of JAX's import time.
  * Setting the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` limits
    which computations are written to the persistent cache: by default, only
    computations that take 1 second or more to compile are cached.
* Added {func}`jax.scipy.stats.mode`.
* The default device order used by `pmap` on TPU if no order is specified now
matches `jax.devices()` for single-process jobs. Previously the
two orderings differed, which could lead to unnecessary copies or
out-of-memory errors. Requiring the orderings to agree simplifies matters.
* Breaking Changes
* {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`,
and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)
* Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
require inputs to be array-like: i.e. lists and tuples cannot be used in place
of arrays. Part of {jax-issue}`#7737`.
* Deprecations
* `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`.
`jax.sharding.MeshPspecSharding` name will be removed in 3 months.
## jaxlib 0.3.24 (Nov 4, 2022)
* Changes
* Buffer donation now works on CPU. This may break code that marked buffers
for donation on CPU but relied on donation not being implemented.
## jax 0.3.23 (Oct 12, 2022)
* Changes
* Update Colab TPU driver version for new jaxlib release.
## jax 0.3.22 (Oct 11, 2022)
* Changes
  * Add `JAX_PLATFORMS=tpu,cpu` as the default setting in TPU initialization,
    so JAX will raise an error if the TPU cannot be initialized instead of falling
    back to CPU. Set `JAX_PLATFORMS=''` to override this behavior and automatically
    choose an available backend (the original default), or set `JAX_PLATFORMS=cpu`
    to always use CPU regardless of whether a TPU is available.
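
    A minimal sketch of overriding the new default (set before importing `jax`):

    ```python
    import os

    os.environ["JAX_PLATFORMS"] = ""      # restore automatic backend selection
    # os.environ["JAX_PLATFORMS"] = "cpu" # or force CPU even when a TPU is present

    import jax
    ```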
* Deprecations
* Several test utilities deprecated in JAX v0.3.8 are now removed from
{mod}`jax.test_util`.
## jaxlib 0.3.22 (Oct 11, 2022)
## jax 0.3.21 (Sep 30, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.20...jax-v0.3.21).
* Changes
* The persistent compilation cache will now warn instead of raising an
exception on error ({jax-issue}`#12582`), so program execution can continue
if something goes wrong with the cache. Set
`JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` to revert this behavior.
## jax 0.3.20 (Sep 28, 2022)
* Bug fixes:
* Adds missing `.pyi` files that were missing from the previous release ({jax-issue}`#12536`).
* Fixes an incompatibility between `jax` 0.3.19 and the libtpu version it pinned ({jax-issue}`#12550`). Requires jaxlib 0.3.20.
* Fix incorrect `pip` url in `setup.py` comment ({jax-issue}`#12528`).
## jaxlib 0.3.20 (Sep 28, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20).
* Bug fixes
* Fixes support for limiting the visible CUDA devices via
`jax_cuda_visible_devices` in distributed jobs. This functionality is needed for
the JAX/SLURM integration on GPU ({jax-issue}`#12533`).
## jax 0.3.19 (Sep 27, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.18...jax-v0.3.19).
* Fixes required jaxlib version.
## jax 0.3.18 (Sep 26, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.17...jax-v0.3.18).
* Changes
* Ahead-of-time lowering and compilation functionality (tracked in
{jax-issue}`#7733`) is stable and public. See [the
overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs
for {mod}`jax.stages`.
* Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks
and type annotations for array types in JAX. Notice that this included some subtle
changes to how `isinstance` works for {class}`jax.numpy.ndarray` for jax-internal
objects, as {class}`jax.numpy.ndarray` is now a simple alias of {class}`jax.Array`.
* Breaking changes
* `jax._src` is no longer imported into the public `jax` namespace.
This may break users that were using JAX internals.
* `jax.soft_pmap` has been deleted. Please use `pjit` or `xmap` instead.
`jax.soft_pmap` is undocumented. If it were documented, a deprecation period
would have been provided.
## jax 0.3.17 (Aug 31, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.16...jax-v0.3.17).
* Bugs
  * Fixed a corner-case issue in the gradient of `lax.pow` with an exponent of zero
    ({jax-issue}`#12041`).
* Breaking changes
* {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports
the `concrete` option, following the previous version's deprecation; see
[JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
* Changes
* Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
* Deprecations:
* The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile`
({jax-issue}`#11944`).
* `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
## jax 0.3.16
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main).
* Breaking changes
* Support for NumPy 1.19 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to NumPy 1.20 or newer.
* Changes
  * Added {mod}`jax.debug`, which includes utilities for runtime value debugging such as {func}`jax.debug.print` and {func}`jax.debug.breakpoint`.
* Added new documentation for [runtime value debugging](debugging/index)
* Deprecations
  * The {func}`jax.mask` and {func}`jax.shapecheck` APIs have been removed.
    See {jax-issue}`#11557`.
* {mod}`jax.experimental.loops` has been removed. See {jax-issue}`#10278`
for an alternative API.
* {func}`jax.tree_util.tree_multimap` has been removed. It has been deprecated since
JAX release 0.3.5, and {func}`jax.tree_util.tree_map` is a direct replacement.
* Removed `jax.experimental.stax`; it has long been a deprecated alias of
{mod}`jax.example_libraries.stax`.
* Removed `jax.experimental.optimizers`; it has long been a deprecated alias of
{mod}`jax.example_libraries.optimizers`.
* {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new
implementation switched on by default, meaning the old implementation is
deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
## jax 0.3.15 (July 22, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15).
* Changes
* `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
* Added {class}`jax.scipy.gaussian_kde` ({jax-issue}`#11237`).
* Binary operations between JAX arrays and built-in collections (`dict`, `list`, `set`, `tuple`)
now raise a `TypeError` in all cases. Previously some cases (particularly equality and inequality)
would return boolean scalars inconsistent with similar operations in NumPy ({jax-issue}`#11234`).
* Several {mod}`jax.tree_util` routines accessed as top-level JAX package imports are now
deprecated, and will be removed in a future JAX release in accordance with the
{ref}`api-compatibility` policy:
* {func}`jax.treedef_is_leaf` is deprecated in favor of {func}`jax.tree_util.treedef_is_leaf`
* {func}`jax.tree_flatten` is deprecated in favor of {func}`jax.tree_util.tree_flatten`
* {func}`jax.tree_leaves` is deprecated in favor of {func}`jax.tree_util.tree_leaves`
* {func}`jax.tree_structure` is deprecated in favor of {func}`jax.tree_util.tree_structure`
* {func}`jax.tree_transpose` is deprecated in favor of {func}`jax.tree_util.tree_transpose`
* {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
following a similar deprecation in {func}`scipy.linalg.solve`.
## jaxlib 0.3.15 (July 22, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15).
## jax 0.3.14 (June 27, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.13...jax-v0.3.14).
* Breaking changes
  * {func}`jax.experimental.compilation_cache.initialize_cache` no longer
    supports the `max_cache_size_bytes` argument and will not accept it as an
    input.
* `JAX_PLATFORMS` now raises an exception when platform initialization fails.
* Changes
* Fixed compatibility problems with NumPy 1.23.
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
that allows selection between an LU-decomposition based implementation and
an implementation based on QR decomposition.
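    A minimal sketch of selecting an implementation (a sketch, assuming the
    argument values `"lu"` and `"qr"`):
    ```python
    import jax.numpy as jnp

    a = jnp.eye(3)
    sign, logabsdet = jnp.linalg.slogdet(a, method="qr")  # or method="lu"
    ```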
* {func}`jax.numpy.linalg.qr` now supports `mode="raw"`.
* `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when
used on jax arrays ({jax-issue}`#10659`). In particular:
- `pickle` and `deepcopy` previously returned `np.ndarray` objects when used
on a `DeviceArray`; now `DeviceArray` objects are returned. For `deepcopy`,
the copied array is on the same device as the original. For `pickle` the
deserialized array will be on the default device.
- Within function transformations (i.e. traced code), `deepcopy` and `copy`
previously were no-ops. Now they use the same mechanism as `DeviceArray.copy()`.
- Calling `pickle` on a traced array now results in an explicit
`ConcretizationTypeError`.
* The implementation of singular value decomposition (SVD) and
symmetric/Hermitian eigendecomposition should be significantly faster on
TPU, especially for matrices above 1000x1000 or so. Both now use a spectral
divide-and-conquer algorithm for eigendecomposition (QDWH-eig).
* {func}`jax.numpy.ldexp` no longer silently promotes all inputs to float64,
instead it promotes to float32 for integer inputs of size int32 or smaller
({jax-issue}`#10921`).
  * Added a `create_perfetto_link` option to {func}`jax.profiler.start_trace` and
    {func}`jax.profiler.start_server`. When used, the profiler will generate a
link to the Perfetto UI to view the trace.
* Changed the semantics of {func}`jax.profiler.start_server(...)` to store the
keepalive globally, rather than requiring the user to keep a reference to
it.
* Added {func}`jax.random.generalized_normal`.
* Added {func}`jax.random.ball`.
* Added {func}`jax.default_device`.
* Added a `python -m jax.collect_profile` script to manually capture program
traces as an alternative to the TensorBoard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`).
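    A minimal sketch (the scope name and function body are illustrative):
    ```python
    import jax

    @jax.jit
    def f(x):
        with jax.named_scope("my_block"):  # the name shows up in profiler metadata
            return x * 2
    ```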
  * In scatter-update operations (i.e. {attr}`jax.numpy.ndarray.at`), unsafe implicit
dtype casts are deprecated, and now result in a `FutureWarning`.
In a future release, this will become an error. An example of an unsafe implicit
cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was
silently truncated to `1`.
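    A sketch of the pattern that now warns, and an explicit cast that avoids it
    (values illustrative):
    ```python
    import jax.numpy as jnp

    x = jnp.zeros(4, dtype=int)
    # x.at[0].set(1.5)           # implicit downcast: FutureWarning, later an error
    x.at[0].set(jnp.int32(1.5))  # cast explicitly if truncation is intended
    ```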
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
bucket path as input.
* Added {func}`jax.scipy.stats.gennorm`.
  * {func}`jax.numpy.roots` is now better behaved with `strip_zeros=False` when
coefficients have leading zeros ({jax-issue}`#11215`).
## jaxlib 0.3.14 (June 27, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14).
* x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14
was released in 2018, so this should not be a very onerous requirement.
* The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
* The Python flatbuffers package is no longer a dependency of jaxlib.
## jax 0.3.13 (May 16, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.12...jax-v0.3.13).
## jax 0.3.12 (May 15, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.11...jax-v0.3.12).
* Changes
* Fixes [#10717](https://github.com/jax-ml/jax/issues/10717).
## jax 0.3.11 (May 15, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.10...jax-v0.3.11).
* Changes
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
that allows users to opt out of eigenvalue sorting on TPU.
* Deprecations
* Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked
    keyword-only. As a backward-compatibility step, passing keyword-only
arguments positionally yields a warning, but in a future JAX release passing
keyword-only arguments positionally will fail.
However, most users should prefer to use {mod}`jax.numpy.linalg` instead.
* {func}`jax.scipy.linalg.polar_unitary`, which was a JAX extension to the
scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
## jax 0.3.10 (May 3, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.9...jax-v0.3.10).
## jaxlib 0.3.10 (May 3, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10).
* Changes
* [TF commit](https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a)
fixes an issue in the MHLO canonicalizer that caused constant folding to
take a long time or crash for certain programs.
## jax 0.3.9 (May 2, 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.8...jax-v0.3.9).
* Changes
* Added support for fully asynchronous checkpointing for GlobalDeviceArray.
## jax 0.3.8 (April 29 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.7...jax-v0.3.8).
* Changes
* {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
* {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
* {func}`jax.scipy.cluster.vq.vq` has been added.
* `jax.experimental.maps.mesh` has been deleted.
    Please use `jax.experimental.maps.Mesh`; see
    https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
    for more information.
* {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when
    `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`).
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
that specifies the behavior of out-of-bounds indexing. By default,
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
previous versions of JAX, invalid indices were clamped into range. The
previous behavior can be restored by passing `mode="clip"`.
* {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns
invalid values (e.g., NaN) for out-of-bounds indices.
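    A short sketch of the two modes (values illustrative):
    ```python
    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0, 3.0])
    jnp.take(x, jnp.array([0, 5]))               # mode="fill": index 5 yields NaN
    jnp.take(x, jnp.array([0, 5]), mode="clip")  # old behavior: index 5 clamps to 2
    ```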
* Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics.
This has no effect on the scatter operation itself, but it means that when
differentiated the gradient of a scatter will yield zero cotangents for
out-of-bounds indices. Previously out-of-bounds indices were clamped into
range for the gradient, which was not mathematically correct.
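    A sketch of the effect on gradients (shapes and indices illustrative):
    ```python
    import jax
    import jax.numpy as jnp

    def f(y):
        # index 5 is out of bounds for a length-3 array, so that update is dropped
        return jnp.zeros(3).at[jnp.array([0, 5])].set(y).sum()

    jax.grad(f)(jnp.ones(2))  # [1., 0.]: zero cotangent for the dropped update
    ```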
* {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices
are not of an integer type, matching the behavior of
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
cast to integers.
* {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument
is not of an integer type, matching the behavior of
{func}`numpy.ravel_multi_index`. Previously non-integer `dims` was silently
cast to integers.
* {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument
is not of an integer type, matching the behavior of
{func}`numpy.split`. Previously non-integer `axis` was silently
cast to integers.
* {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions
are not of an integer type, matching the behavior of
{func}`numpy.indices`. Previously non-integer dimensions were silently
cast to integers.
* {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument
is not of an integer type, matching the behavior of
{func}`numpy.diag`. Previously non-integer `k` was silently
cast to integers.
* Added {func}`jax.random.orthogonal`.
* Deprecations
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
`format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and
`_default_tolerance` ({jax-issue}`#10389`). These, along with previously-deprecated `JaxTestCase`,
`JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release.
Most of these utilities can be replaced by calls to standard python & numpy testing utilities found
in e.g. {mod}`unittest`, {mod}`absl.testing`, {mod}`numpy.testing`, etc. JAX-specific functionality
such as device checking can be replaced through the use of public APIs such as {func}`jax.devices`.
Many of the deprecated utilities will still exist in {mod}`jax._src.test_util`, but these are not
public APIs and as such may be changed or removed without notice in future releases.
## jax 0.3.7 (April 15, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.6...jax-v0.3.7).
* Changes:
* Fixed a performance problem if the indices passed to
{func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`).
* {func}`jax.scipy.special.expit` and {func}`jax.scipy.special.logit` now
require their arguments to be scalars or JAX arrays. They also now promote
integer arguments to floating point.
* The `DeviceArray.tile()` method is deprecated, because numpy arrays do not have a
`tile()` method. As a replacement for this, use {func}`jax.numpy.tile`
({jax-issue}`#10266`).
## jaxlib 0.3.7 (April 15, 2022)
* Changes:
* Linux wheels are now built conforming to the `manylinux2014` standard, instead
of `manylinux2010`.
## jax 0.3.6 (April 12, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.5...jax-v0.3.6).
* Changes:
* Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU
pod. Fixes [#10218](https://github.com/jax-ml/jax/issues/10218).
* Deprecations:
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
for an alternative API.
## jax 0.3.5 (April 7, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.4...jax-v0.3.5).
* Changes:
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
* the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
* added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`,
and {func}`jax.numpy.fromstring` ({jax-issue}`#10049`).
* `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`#10069`)
* added {func}`jax.scipy.linalg.rsf2csf`
* `jax.experimental.sharded_jit` has been deprecated and will be removed soon.
* Deprecations:
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
* {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`).
* `jax.experimental.sharded_jit` is deprecated. Use `pjit` instead.
## jaxlib 0.3.5 (April 7, 2022)
* Bug fixes
* Fixed a bug where double-precision complex-to-real IRFFTs would mutate their
input buffers on GPU ({jax-issue}`#9946`).
* Fixed incorrect constant-folding of complex scatters ({jax-issue}`#10159`)
## jax 0.3.4 (March 18, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.3...jax-v0.3.4).
## jax 0.3.3 (March 17, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.2...jax-v0.3.3).
## jax 0.3.2 (March 16, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.1...jax-v0.3.2).
* Changes:
* The functions `jax.ops.index_update`, `jax.ops.index_add`, which were
deprecated in 0.2.22, have been removed. Please use
[the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
instead, e.g., `x.at[idx].set(y)`.
* Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are
optimized alternatives to `jax.lax.top_k`.
* {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar
or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`).
* The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
* `pjit` now works on CPU (in addition to previous TPU and GPU support).
## jaxlib 0.3.2 (March 16, 2022)
* Changes
* ``XlaComputation.as_hlo_text()`` now supports printing large constants by
passing boolean flag ``print_large_constants=True``.
* Deprecations:
* The ``.block_host_until_ready()`` method on JAX arrays has been deprecated.
Use ``.block_until_ready()`` instead.
## jax 0.3.1 (Feb 18, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.0...jax-v0.3.1).
* Changes:
* `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated.
The suggested replacement is to use `parametrized.TestCase` directly. For tests that
rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement
is to use standard numpy testing utilities such as {func}`numpy.testing.assert_allclose()`,
which work directly with JAX arrays ({jax-issue}`#9620`).
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default
({jax-issue}`#9562`). To recover the previous behavior, use the new
`jax.test_util.with_config` decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
...
```
* Added {func}`jax.scipy.linalg.schur`, {func}`jax.scipy.linalg.sqrtm`,
{func}`jax.scipy.signal.csd`, {func}`jax.scipy.signal.stft`,
{func}`jax.scipy.signal.welch`.
## jax 0.3.0 (Feb 10, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0).
* Changes
* jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html)
for the explanation.
## jaxlib 0.3.0 (Feb 10, 2022)
* Changes
* Bazel 5.0.0 is now required to build jaxlib.
* jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html)
for the explanation.
## jax 0.2.28 (Feb 1, 2022)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.27...jax-v0.2.28).
* `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no
`dialect=` is passed.
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
`ir.Module` object instead of its string representation.
## jaxlib 0.1.76 (Jan 27, 2022)
* New features
  * Includes precompiled SASS for NVidia compute capability 8.0 GPUs
(e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not
to increase the number of compute capabilities: GPUs with compute capability
6.1 can use the 6.0 SASS.
* With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR
by default.
* Breaking changes
* Support for NumPy 1.18 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* Bug fixes
  * Fixed a bug where apparently identical pytreedef objects constructed by different routes
    did not compare as equal (#9066).
  * The JAX jit cache now requires two static arguments to have identical types for a cache hit (#9311).
## jax 0.2.27 (Jan 18 2022)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.26...jax-v0.2.27).
* Breaking changes:
* Support for NumPy 1.18 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The host_callback primitives have been simplified to drop the
special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
    environment variable, or the ``--jax_host_callback_ad_transforms`` flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#8678`).
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the
bit representation. In particular, ``0.0`` and ``-0.0`` are now treated as equivalent,
where previously ``-0.0`` was treated as less than ``0.0``. Additionally all ``NaN``
representations are now treated as equivalent and sorted to the end of the array.
Previously negative ``NaN`` values were sorted to the front of the array, and ``NaN``
values with different internal bit representations were not treated as equivalent, and
were sorted according to those bit patterns ({jax-issue}`#9178`).
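    A small sketch of the new ordering (values illustrative):
    ```python
    import jax.numpy as jnp

    x = jnp.array([jnp.nan, 1.0, -0.0, 0.0])
    jnp.sort(x)  # zeros compare equal regardless of sign; NaNs sort to the end
    ```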
* {func}`jax.numpy.unique` now treats ``NaN`` values in the same way as `np.unique` in
NumPy versions 1.21 and newer: at most one ``NaN`` value will appear in the uniquified
    output ({jax-issue}`#9184`).
* Bug fixes:
* host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).
* New features:
  * Added `jax.block_until_ready` ({jax-issue}`#8941`).
* Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`.
If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
file under the given path.
* Added `jax.ensure_compile_time_eval` to the public api ({jax-issue}`#7987`).
* jax2tf now supports a flag jax2tf_associative_scan_reductions to change
the lowering for associative reductions, e.g., jnp.cumsum, to behave
like JAX on CPU and GPU (to use an associative scan). See the jax2tf README
for more details ({jax-issue}`#9189`).
## jaxlib 0.1.75 (Dec 8, 2021)
* New features:
* Support for python 3.10.
## jax 0.2.26 (Dec 8, 2021)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.25...jax-v0.2.26).
* Bug fixes:
* Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with
`FILL_OR_DROP` semantics, as documented. This primarily affects the
reverse-mode derivative, where gradients corresponding to out-of-bounds
indices will now be returned as 0. (#8634).
* jax2tf will force the converted code to use XLA for the code fragments
under jax.jit, e.g., most jax.numpy functions ({jax-issue}`#7839`).
## jaxlib 0.1.74 (Nov 17, 2021)
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via
the host, which is usually slower.
* Added experimental MLIR Python bindings for use by JAX.
## jax 0.2.25 (Nov 10, 2021)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.24...jax-v0.2.25).
* New features:
* (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
* `jax.random.permutation` supports new `independent` keyword argument
({jax-issue}`#8430`)
* Breaking changes
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
* New features:
* Added `jax.lax.linalg.qdwh`.
## jax 0.2.24 (Oct 19, 2021)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.22...jax-v0.2.24).
* New features:
* `jax.random.choice` and `jax.random.permutation` now support
multidimensional arrays and an optional `axis` argument ({jax-issue}`#8158`)
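    A minimal sketch of the new `axis` argument (shapes illustrative):
    ```python
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    x = jnp.arange(6).reshape(2, 3)
    jax.random.permutation(key, x, axis=1)  # permutes columns rather than rows
    ```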
* Breaking changes:
* `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs
(see {jax-issue}`#7737`)
## jaxlib 0.1.73 (Oct 18, 2021)
* Multiple cuDNN versions are now supported for jaxlib GPU `cuda11` wheels.
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
installation is new enough, since it supports additional functionality.
* cuDNN 8.0.5 or newer.
* Breaking changes:
* The install commands for GPU jaxlib are as follows:
```bash
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
    # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
    pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
    # Installs the wheel compatible with CUDA 11 and cuDNN 8.0.5 or newer.
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
## jax 0.2.22 (Oct 12, 2021)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.21...jax-v0.2.22).
* Breaking Changes
* Static arguments to `jax.pmap` must now be hashable.
Unhashable static arguments have long been disallowed on `jax.jit`, but they
were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static
arguments using object identity.
This behavior is a footgun, since comparing arguments using
object identity leads to recompilation each time the object identity
changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap`
wants to compare static arguments by object identity, they can define
`__hash__` and `__eq__` methods on their objects that do that, or wrap their
objects in an object that has those operations with object identity
semantics. Another option is to use `functools.partial` to encapsulate the
unhashable static arguments into the function object.
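    A sketch of the `functools.partial` workaround (function and values
    illustrative):
    ```python
    import functools
    import jax

    def step(x, cfg):
        return x * cfg[0]

    cfg = [2.0, 3.0]  # a list is unhashable, so it cannot be a static argument
    # jax.pmap(step, static_broadcasted_argnums=1)  # would now fail when passed cfg
    f = jax.pmap(functools.partial(step, cfg=cfg))  # close over it instead
    ```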
* `jax.util.partial` was an accidental export that has now been removed. Use
`functools.partial` from the Python standard library instead.
* Deprecations
* The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are
deprecated and will be removed in a future JAX release. Please use
[the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a
`DeprecationWarning`.
* New features:
* An optimized C++ code-path improving the dispatch time for `pmap` is now the
default when using jaxlib 0.1.72 or newer. The feature can be disabled using
the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
* `jax.numpy.unique` now supports an optional `fill_value` argument ({jax-issue}`#8121`)
## jaxlib 0.1.72 (Oct 12, 2021)
* Breaking changes:
* Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 11.1+.
* Bug fixes:
* Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong
outputs on all platforms due to incorrect buffer aliasing inside the XLA
compiler.
## jax 0.2.21 (Sept 23, 2021)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking Changes
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
`jax.*` instead.
* `jax.partial`, and `jax.lax.partial` were accidental exports that have now
been removed. Use `functools.partial` from the Python standard library
instead.
* Boolean scalar indices now raise a `TypeError`; previously this silently
returned wrong results ({jax-issue}`#7925`).
* Many more `jax.numpy` functions now require array-like inputs, and will error
if passed a list ({jax-issue}`#7747` {jax-issue}`#7802` {jax-issue}`#7907`).
See {jax-issue}`#7737` for a discussion of the rationale behind this change.
* When inside a transformation such as `jax.jit`, `jax.numpy.array` always
stages the array it produces into the traced computation. Previously
    `jax.numpy.array` would sometimes produce an on-device array, even under
a `jax.jit` decorator. This change may break code that used JAX arrays to
perform shape or index computations that must be known statically; the
workaround is to perform such computations using classic NumPy arrays
instead.
* `jnp.ndarray` is now a true base-class for JAX arrays. In particular, this
means that for a standard numpy array `x`, `isinstance(x, jnp.ndarray)` will
    now return `False` ({jax-issue}`#7927`).
* New features:
* Added {func}`jax.numpy.insert` implementation ({jax-issue}`#7936`).
## jax 0.2.20 (Sept 2, 2021)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20).
* Breaking Changes
* `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`)
* `jnp.unique` and other set-like operations now require array-like inputs
({jax-issue}`#7662`)
## jaxlib 0.1.71 (Sep 1, 2021)
* Breaking changes:
* Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 10.2 and CUDA 11.1+.
## jax 0.2.19 (Aug 12, 2021)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19).
* Breaking changes:
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The `jit` decorator has been added around the implementation of a number of
operators on JAX arrays. This speeds up dispatch times for common
operators such as `+`.
This change should largely be transparent to most users. However, there is
one known behavioral change, which is that large integer constants may now
produce an error when passed directly to a JAX operator
(e.g., `x + 2**40`). The workaround is to cast the constant to an
explicit type (e.g., `np.float64(2**40)`).
* New features:
* Improved the support for shape polymorphism in jax2tf for operations that
need to use a dimension size in array computation, e.g., `jnp.mean`.
({jax-issue}`#7317`)
* Bug fixes:
* Some leaked trace errors from the previous release ({jax-issue}`#7613`)
## jaxlib 0.1.70 (Aug 9, 2021)
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The host_callback mechanism now uses one thread per local device for
making the calls to the Python callbacks. Previously there was a single
thread for all devices. This means that the callbacks may now be called
interleaved. The callbacks corresponding to one device will still be
called in sequence.
## jax 0.2.18 (July 21 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18).
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* The minimum jaxlib version is now 0.1.69.
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
removed.
* New features:
* Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).
* Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
not used with an invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`#7196`)
## jaxlib 0.1.69 (July 9 2021)
* Fixed bugs in the TFRT CPU backend that resulted in incorrect results.
## jax 0.2.17 (July 9 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
* Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
to work around #7229, which caused wrong outputs on CPU due to a concurrency
problem.
* New features:
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
* Reverse-mode autodiff functions ({func}`jax.grad`,
{func}`jax.value_and_grad`, {func}`jax.vjp`, and
{func}`jax.linear_transpose`) support a parameter that indicates which named
axes should be summed over in the backward pass if they were broadcasted
over in the forward pass. This enables use of these APIs in a
non-per-example way inside maps (initially only
{func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).
## jax 0.2.16 (June 23 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.15...jax-v0.2.16).
## jax 0.2.15 (June 23 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.14...jax-v0.2.15).
* New features:
* [#7042](https://github.com/jax-ml/jax/pull/7042) Turned on TFRT CPU backend
with significant dispatch performance improvements on CPU.
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans
({jax-issue}`#6956`).
* New SciPy function {py:func}`jax.scipy.special.lpmn_values`.
* Breaking changes:
* Support for NumPy 1.16 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
* Bug fixes:
* Fixed bug that prevented round-tripping from JAX to TF and back:
`jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`#6947`).
## jaxlib 0.1.68 (June 23 2021)
* Bug fixes:
  * Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a
    TPU buffer to CPU.
## jax 0.2.14 (June 10 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.13...jax-v0.2.14).
* New features:
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
* A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters
tracebacks.
* A new traceback filtering mode using `__tracebackhide__` is now enabled by
default in sufficiently recent versions of IPython.
* The {func}`jax2tf.convert` supports shape polymorphism even when the
unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
({jax-issue}`#6827`).
* The {func}`jax2tf.convert` generates custom attributes with location information
in TF ops. The code that XLA generates after jax2tf
has the same location information as JAX/XLA.
* New SciPy function {py:func}`jax.scipy.special.lpmn`.
* Bug fixes:
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
for Python scalars and for choosing 32-bit vs. 64-bit computations
as JAX ({jax-issue}`#6883`).
* The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter
properly to apply only during the just-in-time conversion
({jax-issue}`#6720`).
* The {func}`jax2tf.convert` now converts `lax.dot_general` using the
`XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
({jax-issue}`#6717`).
* The {func}`jax2tf.convert` now has support for inequality comparisons and
min/max for complex numbers ({jax-issue}`#6892`).
## jaxlib 0.1.67 (May 17 2021)
## jaxlib 0.1.66 (May 11 2021)
* New features:
* CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.
NVidia now promises compatibility between CUDA minor releases starting with
CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that
is compatible with CUDA 11.2 and 11.3.
There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use
the CUDA 11.1 wheel for those versions (cuda111).
* Jaxlib now bundles `libdevice.10.bc` in CUDA wheels. There should be no need
to point JAX to a CUDA installation to find this file.
* Added automatic support for static keyword arguments to the {func}`jit`
implementation.
* Added support for pretransformation exception traces.
  * Initial support for pruning unused arguments from {func}`jit`-transformed
    computations. Pruning is still a work in progress.
* Improved the string representation of {class}`PyTreeDef` objects.
* Added support for XLA's variadic ReduceWindow.
* Bug fixes:
* Fixed a bug in the remote cloud TPU support when large numbers of arguments
are passed to a computation.
* Fix a bug that meant that JAX garbage collection was not triggered by
{func}`jit` transformed functions.
## jax 0.2.13 (May 3 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.12...jax-v0.2.13).
* New features:
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
keyword arguments as static.
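    A minimal sketch (the argument name `mode` is illustrative):
    ```python
    import functools
    import jax

    @functools.partial(jax.jit, static_argnames="mode")
    def f(x, mode):
        # `mode` is a compile-time constant; each new value triggers a recompile
        return x * 2 if mode == "double" else x / 2
    ```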
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`#6501`)
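    A short sketch of `size` making the output shape static (values illustrative):
    ```python
    import jax
    import jax.numpy as jnp

    @jax.jit
    def first_two_nonzero(x):
        # A fixed `size` gives a static output shape, so nonzero works under jit;
        # missing entries are padded (with zeros by default).
        return jnp.nonzero(x, size=2)
    ```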
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
* {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`).
* Added {func}`jax.scipy.linalg.eigh_tridiagonal` that computes the
eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at
present.
* The order of the filtered and unfiltered stack traces in exceptions has been
changed. The traceback attached to an exception thrown from JAX-transformed
code is now filtered, with an `UnfilteredStackTrace` exception
containing the original trace as the `__cause__` of the filtered exception.
Filtered stack traces now also work with Python 3.6.
* If an exception is thrown by code that has been transformed by reverse-mode
automatic differentiation, JAX now attempts to attach as a `__cause__` of
the exception a `JaxStackTraceBeforeTransformation` object that contains the
stack trace that created the original operation in the forward pass.
Requires jaxlib 0.1.66.
* Breaking changes:
* The following function names have changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed
so please change your code.
* `host_id` --> {func}`~jax.process_index`
* `host_count` --> {func}`~jax.process_count`
* `host_ids` --> `range(jax.process_count())`
* Similarly, the argument to {func}`~jax.local_devices` has been renamed from
`host_id` to `process_index`.
* Arguments to {func}`jax.jit` other than the function are now marked as
keyword-only. This change is to prevent accidental breakage when arguments
are added to `jit`.
* Bug fixes:
* The {func}`jax2tf.convert` now works in presence of gradients for functions
with integer inputs ({jax-issue}`#6360`).
* Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured
`tf.Variable` ({jax-issue}`#6572`).
## jaxlib 0.1.65 (April 7 2021)
## jax 0.2.12 (April 1 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.11...v0.2.12).
* New features
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
* {func}`jax.lax.reduce` is now differentiable.
* Breaking changes:
* The minimum jaxlib version is now 0.1.64.
* Some profiler APIs names have been changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed
so please change your code.
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
* Omnistaging can no longer be disabled. See [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md)
for more information.
* Python integers larger than the maximum `int64` value will now lead to an overflow
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
* Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an
`OverflowError` rather than having their value silently truncated.
* Bug fixes:
* `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
  * {func}`jax.random.randint` clips rather than wraps out-of-bounds limits, and can now generate
integers in the full range of the specified dtype ({jax-issue}`#5868`)
## jax 0.2.11 (March 23 2021)
* [GitHub
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.10...jax-v0.2.11).
* New features:
* [#6112](https://github.com/jax-ml/jax/pull/6112) added context managers:
`jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
`jax.debug_infs`, `jax.log_compiles`.
* [#6085](https://github.com/jax-ml/jax/pull/6085) added `jnp.delete`
* Bug fixes:
* [#6136](https://github.com/jax-ml/jax/pull/6136) generalized
`jax.flatten_util.ravel_pytree` to handle integer dtypes.
* [#6129](https://github.com/jax-ml/jax/issues/6129) fixed a bug with handling
some constants like `enum.IntEnums`
* [#6145](https://github.com/jax-ml/jax/pull/6145) fixed batching issues with
incomplete beta functions
* [#6014](https://github.com/jax-ml/jax/pull/6014) fixed H2D transfers during
tracing
* [#6165](https://github.com/jax-ml/jax/pull/6165) avoids OverflowErrors when
converting some large Python integers to floats
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.
## jaxlib 0.1.64 (March 18 2021)
## jaxlib 0.1.63 (March 17 2021)
## jax 0.2.10 (March 5 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.9...jax-v0.2.10).
* New features:
* {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
* {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
* Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
from JAX ({jax-issue}`#5627`)
and [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
* Extended the batching rule for `lax.pad` to support batching of the padding values.
* Bug fixes:
* {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`)
* Breaking changes:
* JAX's promotion rules were adjusted to make promotion more consistent and
invariant to JIT. In particular, binary operations can now result in weakly-typed
values when appropriate. The main user-visible effect of the change is that
some operations result in outputs of different precision than before; for
example the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)`
previously returned a `float64` array, and now returns a `bfloat16` array.
JAX's type promotion behavior is described at {ref}`type-promotion`.
* {func}`jax.numpy.linspace` now computes the floor of integer values, i.e.,
rounding towards -inf rather than 0. This change was made to match NumPy
1.20.0.
* {func}`jax.numpy.i0` no longer accepts complex numbers. Previously the
function computed the absolute value of complex arguments. This change was
made to match the semantics of NumPy 1.20.0.
* Several {mod}`jax.numpy` functions no longer accept tuples or lists in place
    of array arguments: {func}`jax.numpy.pad`, {func}`jax.numpy.ravel`,
{func}`jax.numpy.repeat`, {func}`jax.numpy.reshape`.
In general, {mod}`jax.numpy` functions should be used with scalars or array arguments.
## jaxlib 0.1.62 (March 9 2021)
* New features:
* jaxlib wheels are now built to require AVX instructions on x86-64 machines
by default. If you want to use JAX on a machine that doesn't support AVX,
you can build a jaxlib from source using the `--target_cpu_features` flag
to `build.py`. `--target_cpu_features` also replaces
`--enable_march_native`.
## jaxlib 0.1.61 (February 12 2021)
## jaxlib 0.1.60 (February 3 2021)
* Bug fixes:
* Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The
memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
  * `bool`, `int8`, and `uint8` are now considered safe to cast to the
    `bfloat16` NumPy extension type.
## jax 0.2.9 (January 26 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.8...jax-v0.2.9).
* New features:
* Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved
error checking and error messages.
* Add {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`.
These are context managers which allow X64 mode to be temporarily enabled/disabled
within a session.
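    A minimal sketch (illustrative):
    ```python
    import jax.numpy as jnp
    from jax.experimental import enable_x64

    with enable_x64():
        x = jnp.arange(3, dtype=jnp.float64)  # 64-bit dtypes honored in this block
    ```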
* Breaking changes:
* {func}`jax.ops.segment_sum` now drops segment IDs that are out of range rather
than wrapping them into the segment ID space. This was done for performance
reasons.
## jaxlib 0.1.59 (January 15 2021)
## jax 0.2.8 (January 12 2021)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.7...jax-v0.2.8).
* New features:
* Add {func}`jax.closure_convert` for use with higher-order custom
derivative functions. ({jax-issue}`#5244`)
* Add {func}`jax.experimental.host_callback.call` to call a custom Python
function on the host and return a result to the device computation.
({jax-issue}`#5243`)
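    A minimal sketch (the host function is illustrative):
    ```python
    import jax
    import numpy as np
    from jax.experimental import host_callback as hcb

    @jax.jit
    def f(x):
        # `result_shape` describes the shape/dtype of the value returned from the host
        return hcb.call(lambda v: np.sin(v), x, result_shape=x)
    ```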
* Bug fixes:
* `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for
complex inputs ({jax-issue}`#5156`)
* `host_callback.id_tap` now works for `jax.pmap` also. There is an
optional parameter for `id_tap` and `id_print` to request that the
device from which the value is tapped be passed as a keyword argument
to the tap function ({jax-issue}`#5182`).
* Breaking changes:
* `jax.numpy.pad` now takes keyword arguments. Positional argument `constant_values`
has been removed. In addition, passing unsupported keyword arguments raises an error.
* Changes for {func}`jax.experimental.host_callback.id_tap` ({jax-issue}`#5243`):
* Removed support for `kwargs` for {func}`jax.experimental.host_callback.id_tap`.
(This support has been deprecated for a few months.)
* Changed the printing of tuples for {func}`jax.experimental.host_callback.id_print`
to use '(' instead of '['.
* Changed the {func}`jax.experimental.host_callback.id_print` in presence of JVP
to print a pair of primal and tangent. Previously, there were two separate
print operations for the primals and the tangent.
* `host_callback.outfeed_receiver` has been removed (it is not necessary,
and was deprecated a few months ago).
* New features:
* New flag for debugging `inf`, analogous to that for `NaN` ({jax-issue}`#5224`).
## jax 0.2.7 (Dec 4 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.6...jax-v0.2.7).
* New features:
* Add `jax.device_put_replicated`
* Add multi-host support to `jax.experimental.sharded_jit`
* Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
* Add support for building on Windows platforms
* Add support for general in_axes and out_axes in `jax.pmap`
* Add complex support for `jax.numpy.linalg.slogdet`
* Bug fixes:
* Fix higher-than-second order derivatives of `jax.numpy.sinc` at zero
* Fix some hard-to-hit bugs around symbolic zeros in transpose rules
* Breaking changes:
* `jax.experimental.optix` has been deleted, in favor of the standalone
`optax` Python package.
  * Indexing of JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing
has been deprecated in Numpy since v1.16, and in JAX since v0.2.4.
See {jax-issue}`#4564`.
## jax 0.2.6 (Nov 18 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New Features:
* Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter.
See [README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup
* Raise an error on non-hashable static arguments for jax.jit and
xla_computation. See [cb48f42](https://github.com/jax-ml/jax/commit/cb48f42).
* Improve consistency of type promotion behavior ({jax-issue}`#4744`):
* Adding a complex Python scalar to a JAX floating point number respects the precision of
the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
it returned `complex128`.
* Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type
are now independent of the order of arguments. For example:
`jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and
`jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously
the first returned `float64` and the second returned `float16`.
* The contents of the (undocumented) `jax.lax_linalg` linear algebra module
are now exposed publicly as `jax.lax.linalg`.
* `jax.random.PRNGKey` now produces the same results in and out of JIT compilation
({jax-issue}`#4877`).
This required changing the result for a given seed in a few particular cases:
* With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result
outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned
`[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
* Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError`
rather than a `TypeError`. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with `jax_enable_x64=False`
outside JIT, you can use:
    ```python
    from jax import random
    key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
    ```
  * DeviceArray now raises `RuntimeError` instead of `ValueError` when trying
    to access its value after it has been deleted.
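  A sketch of the type-promotion changes described above:
  ```python
  import jax.numpy as jnp

  # Complex Python scalars now respect the precision of the JAX float:
  print((jnp.float32(1) + 1j).dtype)  # complex64 (previously complex128)

  # Promotion no longer depends on argument order:
  print(jnp.result_type(jnp.uint64, jnp.int64, jnp.float16))  # float16
  print(jnp.result_type(jnp.float16, jnp.uint64, jnp.int64))  # float16
  ```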
## jaxlib 0.1.58 (January 12ish 2021)
* Fixed a bug that meant JAX sometimes returned platform-specific types (e.g.,
  `np.cint`) instead of standard types (e.g., `np.int32`). (#4903)
* Fixed a crash when constant-folding certain int16 operations. (#4971)
* Added an `is_leaf` predicate to {func}`pytree.flatten`.
## jaxlib 0.1.57 (November 12 2020)
* Fixed manylinux2010 compliance issues in GPU wheels.
* Switched the CPU FFT implementation from Eigen to PocketFFT.
* Fixed a bug where the hash of bfloat16 values was not correctly initialized
and could change (#4651).
* Add support for retaining ownership when passing arrays to DLPack (#4636).
* Fixed a bug for batched triangular solves with sizes greater than 128 but not
a multiple of 128.
* Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
* Fixed a bug in the profiler where tools were missing (#4427).
* Dropped support for CUDA 10.0.
## jax 0.2.5 (October 27 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.4...jax-v0.2.5).
* Improvements:
* Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`.
* Expanded the set of JAX primitives converted by jax2tf.
See [primitives_with_limited_support.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
## jax 0.2.4 (October 19 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.3...jax-v0.2.4).
* Improvements:
* Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`.
* Deprecations
  * Indexing with non-tuple sequences is now deprecated, following a similar
    deprecation in NumPy. In a future release, this will result in a `TypeError`.
    See {jax-issue}`#4564`.
## jaxlib 0.1.56 (October 14, 2020)
## jax 0.2.3 (October 14 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.2...jax-v0.2.3).
* The reason for another release so soon is that we needed to temporarily roll
  back the new jit fastpath while we investigate a performance degradation.
## jax 0.2.2 (October 13 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.1...jax-v0.2.2).
## jax 0.2.1 (October 6 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.0...jax-v0.2.1).
* Improvements:
* As a benefit of omnistaging, the host_callback functions are executed (in program
order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/
{py:func}`jax.experimental.host_callback.id_tap` is not used in the computation.
## jax 0.2.0 (September 23 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.77...jax-v0.2.0).
* Improvements:
* Omnistaging on by default. See {jax-issue}`#3370` and
[omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md)
## jax 0.1.77 (September 15 2020)
* Breaking changes:
* New simplified interface for {py:func}`jax.experimental.host_callback.id_tap` (#4101)
## jaxlib 0.1.55 (September 8, 2020)
* Update XLA:
* Fix bug in DLPackManagedTensorToBuffer (#4196)
## jax 0.1.76 (September 8, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.75...jax-v0.1.76).
## jax 0.1.75 (July 30, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.74...jax-v0.1.75).
* Bug Fixes:
* make jnp.abs() work for unsigned inputs (#3914)
* Improvements:
* "Omnistaging" behavior added behind a flag, disabled by default (#3370)
## jax 0.1.74 (July 29, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.73...jax-v0.1.74).
* New Features:
* BFGS (#3101)
* TPU support for half-precision arithmetic (#3878)
* Bug Fixes:
* Prevent some accidental dtype warnings (#3874)
* Fix a multi-threading bug in custom derivatives (#3845, #3869)
* Improvements:
* Faster searchsorted implementation (#3873)
* Better test coverage for jax.numpy sorting algorithms (#3836)
## jaxlib 0.1.52 (July 22, 2020)
* Update XLA.
## jax 0.1.73 (July 22, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.72...jax-v0.1.73).
* The minimum jaxlib version is now 0.1.51.
* New Features:
  * jax.image.resize (#3703)
* hfft and ihfft (#3664)
* jax.numpy.intersect1d (#3726)
* jax.numpy.lexsort (#3812)
* `lax.scan` and the `scan` primitive support an `unroll`
parameter for loop unrolling when lowering to XLA
({jax-issue}`#3738`).
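    A minimal sketch of the `unroll` parameter (the step function and values
    here are illustrative):
    ```python
    import jax
    import jax.numpy as jnp

    def step(carry, x):
      new = carry + x
      return new, new  # (next carry, per-step output)

    # unroll=4 unrolls four loop steps per XLA iteration:
    total, partials = jax.lax.scan(step, 0.0, jnp.arange(8.0), unroll=4)
    print(total)  # 28.0
    ```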
* Bug Fixes:
* Fix reduction repeated axis error (#3618)
* Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
* make psum transpose handle zero cotangents (#3653)
* Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
* Support differentiation through jax.lax.all_to_all (#3733)
* address nan issue in jax.scipy.special.zeta (#3777)
* Improvements:
* Many improvements to jax2tf
* Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
* Enable XLA SPMD partitioning by default. (#3151)
* Add support for 0d transpose convolution (#3643)
* Make LU gradient work for low-rank matrices (#3610)
* support multiple_results and custom JVPs in jet (#3657)
* Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
* Implement complex convolutions on CPU and GPU. (#3735)
* Make jnp.take work for empty slices of empty arrays. (#3751)
* Relax dimension ordering rules for dot_general. (#3778)
* Enable buffer donation for GPU. (#3800)
* Add support for base dilation and window dilation to reduce window op… (#3803)
## jaxlib 0.1.51 (July 2, 2020)
* Update XLA.
* Add new runtime support for host_callback.
## jax 0.1.72 (June 28, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.71...jax-v0.1.72).
* Bug fixes:
* Fix an odeint bug introduced in the previous release, see
{jax-issue}`#3587`.
## jax 0.1.71 (June 25, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.70...jax-v0.1.71).
* The minimum jaxlib version is now 0.1.48.
* Bug fixes:
* Allow `jax.experimental.ode.odeint` dynamics functions to close over
values with respect to which we're differentiating
{jax-issue}`#3562`.
## jaxlib 0.1.50 (June 25, 2020)
* Add support for CUDA 11.0.
* Drop support for CUDA 9.2 (we only maintain support for the last four CUDA
  versions).
* Update XLA.
## jaxlib 0.1.49 (June 19, 2020)
* Bug fixes:
* Fix build issue that could result in slow compiles
(<https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1>)
## jaxlib 0.1.48 (June 12, 2020)
* New features:
* Adds support for fast traceback collection.
* Adds preliminary support for on-device heap profiling.
* Implements `np.nextafter` for `bfloat16` types.
* Complex128 support for FFTs on CPU and GPU.
* Bug fixes:
* Improved float64 `tanh` accuracy on GPU.
* float64 scatters on GPU are much faster.
* Complex matrix multiplication on CPU should be much faster.
* Stable sorts on CPU should actually be stable now.
* Concurrency bug fix in CPU backend.
## jax 0.1.70 (June 8, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.69...jax-v0.1.70).
* New features:
* `lax.switch` introduces indexed conditionals with multiple
branches, together with a generalization of the `cond`
primitive
{jax-issue}`#3318`.
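    A minimal sketch (the branch functions here are illustrative):
    ```python
    from jax import lax

    branches = (lambda x: x - 1.0,
                lambda x: x * 2.0,
                lambda x: -x)
    # Applies branches[index] to the operand; the index is clamped into range:
    print(lax.switch(1, branches, 3.0))  # 6.0
    ```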
## jax 0.1.69 (June 3, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.68...jax-v0.1.69).
## jax 0.1.68 (May 21, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.67...jax-v0.1.68).
* New features:
* {func}`lax.cond` supports a single-operand form, taken as the argument
to both branches
{jax-issue}`#2993`.
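    A minimal sketch of the single-operand form (values here are illustrative):
    ```python
    from jax import lax

    x = 2.0
    # The single operand `x` is passed to whichever branch is selected:
    print(lax.cond(x > 0, lambda v: v + 1.0, lambda v: v - 1.0, x))  # 3.0
    ```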
* Notable changes:
* The format of the `transforms` keyword for the {func}`jax.experimental.host_callback.id_tap`
primitive has changed {jax-issue}`#3132`.
## jax 0.1.67 (May 12, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.66...jax-v0.1.67).
* New features:
  * Support for reduction over subsets of a pmapped axis using `axis_index_groups`
    {jax-issue}`#2382` (see the sketch after this list).
* Experimental support for printing and calling host-side Python function from
compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html)
({jax-issue}`#3006`).
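  A sketch of `axis_index_groups`, assuming four available devices:
  ```python
  import jax
  import jax.numpy as jnp
  from jax import lax

  # Sum within device groups {0, 1} and {2, 3} rather than across all devices:
  def f(x):
    return lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])

  print(jax.pmap(f, axis_name='i')(jnp.arange(4.0)))  # [1. 1. 5. 5.]
  ```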
* Notable changes:
* The visibility of names exported from {mod}`jax.numpy` has been
tightened. This may break code that was making use of names that were
previously exported accidentally.
## jaxlib 0.1.47 (May 8, 2020)
* Fixes crash for outfeed.
## jax 0.1.66 (May 5, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.65...jax-v0.1.66).
* New features:
* Support for `in_axes=None` on {func}`pmap`
{jax-issue}`#2896`.
## jaxlib 0.1.46 (May 5, 2020)
* Fixes crash for linear algebra functions on Mac OS X (#432).
* Fixes an illegal instruction crash caused by using AVX512 instructions when
an operating system or hypervisor disabled them (#2906).
## jax 0.1.65 (April 30, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.64...jax-v0.1.65).
* New features:
* Differentiation of determinants of singular matrices
{jax-issue}`#2809`.
* Bug fixes:
* Fix {func}`odeint` differentiation with respect to time of ODEs with
time-dependent dynamics {jax-issue}`#2817`,
also add ODE CI testing.
* Fix {func}`lax_linalg.qr` differentiation
{jax-issue}`#2867`.
## jaxlib 0.1.45 (April 21, 2020)
* Fixes segfault: {jax-issue}`#2755`
* Plumb is_stable option on Sort HLO through to Python.
## jax 0.1.64 (April 21, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.63...jax-v0.1.64).
* New features:
* Add syntactic sugar for functional indexed updates
{jax-issue}`#2684`.
* Add {func}`jax.numpy.linalg.multi_dot` {jax-issue}`#2726`.
* Add {func}`jax.numpy.unique` {jax-issue}`#2760`.
  * Add {func}`jax.numpy.rint` {jax-issue}`#2724`.
* Add more primitive rules for {func}`jax.experimental.jet`.
* Bug fixes:
* Fix {func}`logaddexp` and {func}`logaddexp2` differentiation at zero {jax-issue}`#2107`.
* Improve memory usage in reverse-mode autodiff without {func}`jit`
{jax-issue}`#2719`.
* Better errors:
* Improves error message for reverse-mode differentiation of {func}`lax.while_loop`
{jax-issue}`#2129`.
## jaxlib 0.1.44 (April 16, 2020)
* Fixes a bug where if multiple GPUs of different models were present, JAX
would only compile programs suitable for the first GPU.
* Bugfix for `batch_group_count` convolutions.
* Added precompiled SASS for more GPU versions to avoid startup PTX compilation
hang.
## jax 0.1.63 (April 12, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63).
* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`; see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) and the sketch at the end of this section. Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
* Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
* Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
* Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`.
* Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
* Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
* Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`.
* Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
* Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
* Add `callback_transform` {jax-issue}`#2665`.
* Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`,
`trunc`, `roots`, and `quantile`/`percentile` interpolation options.
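A minimal sketch of the new `jax.custom_jvp` API (see the tutorial notebook
linked above for more, including `jax.custom_vjp`):
```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  # Custom rule: d/dx sin(x) = cos(x)
  return f(x), jnp.cos(x) * t

print(jax.grad(f)(0.0))  # 1.0
```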
## jaxlib 0.1.43 (March 31, 2020)
* Fixed a performance regression for Resnet-50 on GPU.
## jax 0.1.62 (March 21, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.61...jax-v0.1.62).
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
* Removed the internal function `lax._safe_mul`, which implemented the
  convention `0. * nan == 0.`. This change means some programs, when
  differentiated, will produce nans where they previously produced correct
  values, though it ensures nans rather than silently incorrect results are
  produced for other programs. See #2447 and #1052 for details.
* Added an `all_gather` parallel convenience function.
* More type annotations in core code.
## jaxlib 0.1.42 (March 19, 2020)
* jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This
release fixes it again.
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
## jax 0.1.61 (March 17, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.60...jax-v0.1.61).
* Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
supports Python 3.5.
## jax 0.1.60 (March 17, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.59...jax-v0.1.60).
* New features:
  * {py:func}`jax.pmap` has a `static_broadcasted_argnums` argument which allows
    the user to specify arguments that should be treated as compile-time
    constants and broadcast to all devices. It works analogously to
    `static_argnums` in {py:func}`jax.jit`.
* Improved error messages for when tracers are mistakenly saved in global state.
  * Added {py:func}`jax.nn.one_hot` utility function (see the sketch at the end
    of this section).
* Added {mod}`jax.experimental.jet` for exponentially faster
higher-order automatic differentiation.
* Added more correctness checking to arguments of {py:func}`jax.lax.broadcast_in_dim`.
* The minimum jaxlib version is now 0.1.41.
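A minimal sketch of the {py:func}`jax.nn.one_hot` helper mentioned above:
```python
import jax.numpy as jnp
from jax import nn

# Encode class indices [0, 2] over 3 classes:
print(nn.one_hot(jnp.array([0, 2]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]]
```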
## jaxlib 0.1.40 (March 4, 2020)
* Adds experimental support in Jaxlib for TensorFlow profiler, which allows
tracing of CPU and GPU computations from TensorBoard.
* Includes prototype support for multihost GPU computations that communicate via
NCCL.
* Improves performance of NCCL collectives on GPU.
* Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and
RandomGamma implementations.
* Supports device assignments known at XLA compilation time.
## jax 0.1.59 (February 11, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.58...jax-v0.1.59).
* Breaking changes
* The minimum jaxlib version is now 0.1.38.
  * Simplified {py:class}`Jaxpr` by removing the `Jaxpr.freevars` and
    `Jaxpr.bound_subjaxprs` fields. The call primitives (`xla_call`, `xla_pmap`,
`sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a
fully-closed (no `constvars`) jaxpr. Also, added a new field `call_primitive`
to primitives.
* New features:
* Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it
now differentiable in both modes ({jax-issue}`#2091`)
  * JAX now supports DLPack, which allows sharing CPU and GPU arrays in a
    zero-copy way with other libraries, such as PyTorch (see the sketch at the
    end of this section).
* JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another
zero-copy protocol for sharing GPU arrays with other libraries such as CuPy
and Numba.
* JAX CPU device buffers now implement the Python buffer protocol, which allows
zero-copy buffer sharing between JAX and NumPy.
  * Added the `JAX_SKIP_SLOW_TESTS` environment variable to skip tests known to
    be slow.
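A minimal sketch of round-tripping an array through DLPack (consumers such as
PyTorch accept the same capsules; exact consumer APIs vary):
```python
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(4.0)
capsule = jax.dlpack.to_dlpack(x)    # export as a DLPack capsule
y = jax.dlpack.from_dlpack(capsule)  # zero-copy import back into JAX
print(y)  # [0. 1. 2. 3.]
```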
## jaxlib 0.1.39 (February 11, 2020)
* Updates XLA.
## jaxlib 0.1.38 (January 29, 2020)
* CUDA 9.0 is no longer supported.
* CUDA 10.2 wheels are now built by default.
## jax 0.1.58 (January 28, 2020)
* [GitHub commits](https://github.com/jax-ml/jax/compare/46014da21...jax-v0.1.58).
* Breaking changes
* JAX has dropped Python 2 support, because Python 2 reached its end of life on
January 1, 2020. Please update to Python 3.5 or newer.
* New features
  * Forward-mode automatic differentiation (`jvp`) of while loop
    ({jax-issue}`#1980`); see the sketch after this list.
  * New NumPy and SciPy functions:
    * {py:func}`jax.numpy.fft.fft2`
    * {py:func}`jax.numpy.fft.ifft2`
    * {py:func}`jax.numpy.fft.rfft`
    * {py:func}`jax.numpy.fft.irfft`
    * {py:func}`jax.numpy.fft.rfft2`
    * {py:func}`jax.numpy.fft.irfft2`
    * {py:func}`jax.numpy.fft.rfftn`
    * {py:func}`jax.numpy.fft.irfftn`
    * {py:func}`jax.numpy.fft.fftfreq`
    * {py:func}`jax.numpy.fft.rfftfreq`
    * {py:func}`jax.numpy.linalg.matrix_rank`
    * {py:func}`jax.numpy.linalg.matrix_power`
    * {py:func}`jax.scipy.special.betainc`
  * Batched Cholesky decomposition on GPU now uses a more efficient batched
    kernel.
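  A minimal sketch of forward-mode differentiation through a while loop
  (the function and values here are illustrative):
  ```python
  import jax
  from jax import lax

  def f(x):
    # Keep doubling x until it reaches at least 100.
    return lax.while_loop(lambda v: v < 100.0, lambda v: 2.0 * v, x)

  # jvp returns (primal, tangent): f(3.0) == 192.0, df/dx == 64.0
  print(jax.jvp(f, (3.0,), (1.0,)))
  ```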
* Notable bug fixes:
  * With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should
    help with installation.