mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #16775 from froystig:random-api-policy
PiperOrigin-RevId: 549122781
This commit is contained in:
commit
4b72163423
@ -31,7 +31,7 @@ Only public JAX APIs are covered, which includes the following modules:
|
||||
* `jax.numpy`
|
||||
* `jax.ops`
|
||||
* `jax.profiler`
|
||||
* `jax.random`
|
||||
* `jax.random` (see [details below](#numerics-and-randomness))
|
||||
* `jax.scipy`
|
||||
* `jax.tree_util`
|
||||
* `jax.test_util`
|
||||
@ -48,10 +48,31 @@ prefixed with underscores, although we do not entirely comply with this yet.
|
||||
* `jax.core`
|
||||
* `jax.linear_util`
|
||||
* `jax.lib`
|
||||
* `jax.prng`
|
||||
* `jax.interpreters`
|
||||
* `jax.experimental`
|
||||
* `jax.example_libraries`
|
||||
* `jax.extend` (see [details](https://jax.readthedocs.io/en/latest/jax.extend.html))
|
||||
|
||||
This list is not exhaustive.
|
||||
|
||||
These lists are not exhaustive.
|
||||
## Numerics and randomness
|
||||
|
||||
The *exact* values of numerical operations are not guaranteed to be
|
||||
stable across JAX releases. In fact, exact numerics are not
|
||||
necessarily stable at a given JAX version, across accelerator
|
||||
platforms, within or without `jax.jit`, and more.
|
||||
|
||||
For a fixed PRNG key input, the outputs of pseudorandom functions in
|
||||
`jax.random` may vary across JAX versions. The compatibility policy
|
||||
applies only to the output *distribution*. For example, the expression
|
||||
`jax.random.gumbel(jax.random.key(72))` may return a different value
|
||||
across JAX releases, but `jax.random.gumbel` will remain a
|
||||
pseudorandom generator for the Gumbel distribution.
|
||||
|
||||
We try to make such changes to pseudorandom values infrequently. When
|
||||
they happen, the changes are announced in the changelog, but do not
|
||||
follow a deprecation cycle. In some situations, JAX might expose a
|
||||
transient configuration flag that reverts the new behavior, to help
|
||||
users diagnose and update affected code. Such flags will last a
|
||||
deprecation window's amount of time.
|
||||
|
@ -358,12 +358,18 @@ class PrngTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
|
||||
def testRandomDistributionValues(self, case, make_key):
|
||||
"""
|
||||
Tests values output by various distributions. This will catch any unintentional
|
||||
changes to the implementations that could result in different random sequences.
|
||||
Tests values output by various distributions. This will catch any
|
||||
unintentional changes to the implementations that could result in
|
||||
different random sequences.
|
||||
|
||||
Any refactoring of random distributions that leads to non-trivial differences in
|
||||
this test should involve a deprecation cycle following the procedures outlined at
|
||||
https://jax.readthedocs.io/en/latest/api_compatibility.html
|
||||
Any refactoring of random distributions that leads to non-trivial
|
||||
differences in this test should follow the procedure outlined at
|
||||
https://jax.readthedocs.io/en/latest/api_compatibility.html#numerics-and-randomness
|
||||
|
||||
This includes:
|
||||
* Announcing the change in the CHANGELOG.md
|
||||
* Considering adding a flag that reverts the new behavior, made
|
||||
available for a deprecation window's amount of time.
|
||||
"""
|
||||
if config.x64_enabled and case.on_x64 == OnX64.SKIP:
|
||||
self.skipTest("test produces different values when jax_enable_x64=True")
|
||||
|
Loading…
x
Reference in New Issue
Block a user