Merge pull request #16775 from froystig:random-api-policy

PiperOrigin-RevId: 549122781
This commit is contained in:
jax authors 2023-07-18 15:06:56 -07:00
commit 4b72163423
2 changed files with 34 additions and 7 deletions

View File

@ -31,7 +31,7 @@ Only public JAX APIs are covered, which includes the following modules:
* `jax.numpy`
* `jax.ops`
* `jax.profiler`
* `jax.random`
* `jax.random` (see [details below](#numerics-and-randomness))
* `jax.scipy`
* `jax.tree_util`
* `jax.test_util`
@ -48,10 +48,31 @@ prefixed with underscores, although we do not entirely comply with this yet.
* `jax.core`
* `jax.linear_util`
* `jax.lib`
* `jax.prng`
* `jax.interpreters`
* `jax.experimental`
* `jax.example_libraries`
* `jax.extend` (see [details](https://jax.readthedocs.io/en/latest/jax.extend.html))
This list is not exhaustive.
These lists are not exhaustive.
## Numerics and randomness
The *exact* values of numerical operations are not guaranteed to be
stable across JAX releases. In fact, exact numerics are not
necessarily stable at a given JAX version, across accelerator
platforms, within or without `jax.jit`, and more.
For a fixed PRNG key input, the outputs of pseudorandom functions in
`jax.random` may vary across JAX versions. The compatibility policy
applies only to the output *distribution*. For example, the expression
`jax.random.gumbel(jax.random.key(72))` may return a different value
across JAX releases, but `jax.random.gumbel` will remain a
pseudorandom generator for the Gumbel distribution.
We try to make such changes to pseudorandom values infrequently. When
they happen, the changes are announced in the changelog, but do not
follow a deprecation cycle. In some situations, JAX might expose a
transient configuration flag that reverts the new behavior, to help
users diagnose and update affected code. Such flags will last a
deprecation window's amount of time.

View File

@ -358,12 +358,18 @@ class PrngTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testRandomDistributionValues(self, case, make_key):
"""
Tests values output by various distributions. This will catch any unintentional
changes to the implementations that could result in different random sequences.
Tests values output by various distributions. This will catch any
unintentional changes to the implementations that could result in
different random sequences.
Any refactoring of random distributions that leads to non-trivial differences in
this test should involve a deprecation cycle following the procedures outlined at
https://jax.readthedocs.io/en/latest/api_compatibility.html
Any refactoring of random distributions that leads to non-trivial
differences in this test should follow the procedure outlined at
https://jax.readthedocs.io/en/latest/api_compatibility.html#numerics-and-randomness
This includes:
* Announcing the change in the CHANGELOG.md
* Considering adding a flag that reverts the new behavior, made
available for a deprecation window's amount of time.
"""
if config.x64_enabled and case.on_x64 == OnX64.SKIP:
self.skipTest("test produces different values when jax_enable_x64=True")