[docs] Enable doctest for .md documentation files

The current invocation of doctest from GitHub Actions picks up only .rst files.
We now enable .md files as well, and make a few changes to ensure that the
doctests pass on the existing files.

The changes fall into several categories:
  * add a newline before the end of the code block, so that doctest
    picks up the expected output properly
  * update the expected values to match the current behavior
  * disable some doctests that raise expected exceptions, whenever
    I could not get doctest to match the exception details.
    Sometimes +IGNORE_EXCEPTION_DETAIL was enough; other times
    I had to use +SKIP.
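The two directives mentioned above behave as follows. This is a standalone sketch using the standard `doctest` module, not taken from the changed files; the docstring contents are made up for illustration:

```python
import doctest

# A docstring with the two kinds of examples this commit touches:
# +IGNORE_EXCEPTION_DETAIL matches the exception type while ignoring
# the message text, and +SKIP disables the example entirely.
docstring = '''
>>> int("nope")  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: the detail text here does not need to match
>>> 1 / 0  # doctest: +SKIP
'''

parser = doctest.DocTestParser()
test = parser.get_doctest(docstring, {}, "example", None, 0)
runner = doctest.DocTestRunner(verbose=False)
result = runner.run(test)
print(result.failed)  # 0: the first example passes, the second is skipped
```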
George Necula, 2024-06-03 11:27:34 +03:00
parent e81c82605f
commit 0223e830c1
6 changed files with 49 additions and 27 deletions

View File

@@ -139,7 +139,7 @@ jobs:
JAX_ARRAY: 1
PY_COLORS: 1
run: |
-pytest -n auto --tb=short docs
+pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py
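The blank-line issue behind the first category in the commit message can be reproduced with the plain `doctest` module, which underlies pytest's doctest collection. A minimal sketch; the file content below is a made-up example, and the blank line after the expected output is what keeps doctest from swallowing a following markdown fence:

```python
import doctest
import pathlib
import tempfile

# A tiny made-up .md file: the expected output "6" is terminated by a
# blank line, so doctest does not read past it into surrounding markup.
content = "Some prose.\n\n>>> 2 * 3\n6\n\n"

with tempfile.TemporaryDirectory() as d:
    path = pathlib.Path(d) / "example.md"
    path.write_text(content)
    # module_relative=False lets testfile take a filesystem path.
    result = doctest.testfile(str(path), module_relative=False)

print(result.failed, result.attempted)  # 0 1
```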

View File

@@ -45,12 +45,12 @@ way. An example:
>>> # Print lowered HLO
>>> print(lowered.as_text())
-module @jit_f.0 {
-  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
-    %0 = stablehlo.constant dense<2> : tensor<i32>
-    %1 = stablehlo.multiply %0, %arg0 : tensor<i32>
-    %2 = stablehlo.add %1, %arg1 : tensor<i32>
-    return %2 : tensor<i32>
+module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
+    %c = stablehlo.constant dense<2> : tensor<i32>
+    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
+    %1 = stablehlo.add %0, %arg1 : tensor<i32>
+    return %1 : tensor<i32>
}
}
@@ -62,7 +62,8 @@ module @jit_f.0 {
>>> # Execute the compiled function!
>>> compiled(x, y)
-DeviceArray(10, dtype=int32)
+Array(10, dtype=int32, weak_type=True)
+
```
See the {mod}`jax.stages` documentation for more details on what functionality
@@ -83,7 +84,8 @@ that have `shape` and `dtype` attributes:
```python
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
-DeviceArray(10, dtype=int32)
+Array(10, dtype=int32)
+
```
More generally, `lower` only needs its arguments to structurally supply what JAX
@@ -97,18 +99,21 @@ lowering raises an error:
```python
>>> x_1d = y_1d = jnp.arange(3)
->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
+>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
+...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]
>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]
```
Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time
@@ -127,15 +132,16 @@ to invoke the resulting compiled function. Continuing with our example above:
>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
-module @jit_f.1 {
-  func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
-    %0 = stablehlo.constant dense<14> : tensor<i32>
-    %1 = stablehlo.add %0, %arg0 : tensor<i32>
-    return %1 : tensor<i32>
+module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
+    %c = stablehlo.constant dense<14> : tensor<i32>
+    %0 = stablehlo.add %c, %arg0 : tensor<i32>
+    return %0 : tensor<i32>
}
}
>>> lowered_with_x.compile()(5)
-DeviceArray(19, dtype=int32)
+Array(19, dtype=int32, weak_type=True)
+
```
Note that `lower` here takes two arguments as usual, but the subsequent compiled
@@ -149,11 +155,13 @@ shape/dtype structure, it is necessary that the static first argument be a
concrete value. Otherwise, lowering would err:
```python
->>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)
+>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'
>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
-DeviceArray(25, dtype=int32)
+Array(25, dtype=int32)
```
## AOT-compiled functions cannot be transformed
@@ -179,13 +187,15 @@ in transformations. Example:
>>> g_aot = jax.jit(g).lower(z).compile()
>>> jax.vmap(g_jit)(zs)
-DeviceArray([[ 1.,  5.,  9.],
-             [13., 17., 21.],
-             [25., 29., 33.],
-             [37., 41., 45.]], dtype=float32)
+Array([[ 1.,  5.,  9.],
+       [13., 17., 21.],
+       [25., 29., 33.],
+       [37., 41., 45.]], dtype=float32)
->>> jax.vmap(g_aot)(zs)
+>>> jax.vmap(g_aot)(zs)  # doctest: +SKIP
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.
```
A similar error is raised when `g_aot` is involved in autodiff

View File

@@ -21,6 +21,7 @@ Array([0, 0], dtype=uint32)
(2,)
>>> key.dtype
dtype('uint32')
+
```
Starting now, new-style RNG keys can be created with
{func}`jax.random.key`:
@@ -33,6 +34,7 @@ Array((), dtype=key<fry>) overlaying:
()
>>> key.dtype
key<fry>
+
```
This (scalar-shaped) array behaves the same as any other JAX array, except
that its element type is a key (and associated metadata). We can make
@@ -48,6 +50,7 @@ Array((4,), dtype=key<fry>) overlaying:
[0 3]]
>>> key_arr.shape
(4,)
+
```
Aside from switching to a new constructor, most PRNG-related code should
continue to work as expected. You can continue to use keys in
@@ -62,14 +65,17 @@ data = jax.random.uniform(key, shape=(5,))
However, not all numerical operations work on key arrays. They now
intentionally raise errors:
```python
->>> key = key + 1
-ValueError: dtype=key<fry> is not a valid dtype for JAX type promotion.
+>>> key = key + 1  # doctest: +SKIP
+Traceback (most recent call last):
+TypeError: add does not accept dtypes key<fry>, int32.
+
```
If for some reason you need to recover the underlying buffer
(the old-style key), you can do so with {func}`jax.random.key_data`:
```python
>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)
+
```
For old-style keys, {func}`~jax.random.key_data` is an identity operation.
@@ -108,6 +114,7 @@ True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False
+
```
### Type annotations for PRNG Keys
@@ -173,6 +180,7 @@ Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
+
```
### Safe PRNG key use
@@ -322,6 +330,7 @@ which has the following property:
```python
>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True
+
```
PRNG key arrays then have a dtype with the following properties:
```python
@@ -330,6 +339,7 @@ PRNG key arrays then have a dtype with the following properties:
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True
+
```
And in addition to `key.dtype._rules` as outlined for extended dtypes in
general, PRNG dtypes define `key.dtype._impl`, which contains the metadata

View File

@@ -2223,6 +2223,7 @@
"\n",
" >>> jnp.arange(254.0, 258.0).astype('uint8')\n",
" Array([254, 255, 255, 255], dtype=uint8)\n",
"\n",
" ```\n",
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
"\n",

View File

@@ -1143,6 +1143,7 @@ Many such cases are discussed in detail in the sections above; here we list seve
>>> jnp.arange(254.0, 258.0).astype('uint8')
Array([254, 255, 255, 255], dtype=uint8)
+
```
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.

View File

@@ -226,7 +226,7 @@ context manager:
>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
-... z = x + y # doctest: +IGNORE_EXCEPTION_DETAIL
+... z = x + y # doctest: +SKIP
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit