mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[docs] Enable doctest for .md documentation files
The current invocation of doctest from GH actions picked up only .rst files. We enable .md files also, and we make a few changes to ensure that the doctest passes on the existing files. The changes fall into several categories: * add a newline before the end of the code block, for doctest to pick up the expected output properly * update the expected values to match the current behavior * disable some doctests that raise expected exceptions, whenever I could not get doctest to match the exception details. Sometimes +IGNORE_EXCEPTION_DETAIL was enough, and other times I had to use +SKIP.
This commit is contained in:
parent
e81c82605f
commit
0223e830c1
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -139,7 +139,7 @@ jobs:
|
||||
JAX_ARRAY: 1
|
||||
PY_COLORS: 1
|
||||
run: |
|
||||
pytest -n auto --tb=short docs
|
||||
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
|
||||
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py
|
||||
|
||||
|
||||
|
50
docs/aot.md
50
docs/aot.md
@ -45,12 +45,12 @@ way. An example:
|
||||
|
||||
>>> # Print lowered HLO
|
||||
>>> print(lowered.as_text())
|
||||
module @jit_f.0 {
|
||||
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
|
||||
%0 = stablehlo.constant dense<2> : tensor<i32>
|
||||
%1 = stablehlo.multiply %0, %arg0 : tensor<i32>
|
||||
%2 = stablehlo.add %1, %arg1 : tensor<i32>
|
||||
return %2 : tensor<i32>
|
||||
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
|
||||
%c = stablehlo.constant dense<2> : tensor<i32>
|
||||
%0 = stablehlo.multiply %c, %arg0 : tensor<i32>
|
||||
%1 = stablehlo.add %0, %arg1 : tensor<i32>
|
||||
return %1 : tensor<i32>
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,7 +62,8 @@ module @jit_f.0 {
|
||||
|
||||
>>> # Execute the compiled function!
|
||||
>>> compiled(x, y)
|
||||
DeviceArray(10, dtype=int32)
|
||||
Array(10, dtype=int32, weak_type=True)
|
||||
|
||||
```
|
||||
|
||||
See the {mod}`jax.stages` documentation for more details on what functionality
|
||||
@ -83,7 +84,8 @@ that have `shape` and `dtype` attributes:
|
||||
```python
|
||||
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
|
||||
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
|
||||
DeviceArray(10, dtype=int32)
|
||||
Array(10, dtype=int32)
|
||||
|
||||
```
|
||||
|
||||
More generally, `lower` only needs its arguments to structurally supply what JAX
|
||||
@ -97,18 +99,21 @@ lowering raises an error:
|
||||
|
||||
```python
|
||||
>>> x_1d = y_1d = jnp.arange(3)
|
||||
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
|
||||
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
...
|
||||
Traceback (most recent call last):
|
||||
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
|
||||
Argument 'x' compiled with int32[] and called with int32[3]
|
||||
Argument 'y' compiled with int32[] and called with int32[3]
|
||||
|
||||
>>> x_f = y_f = jnp.float32(72.)
|
||||
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
|
||||
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
...
|
||||
Traceback (most recent call last):
|
||||
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
|
||||
Argument 'x' compiled with int32[] and called with float32[]
|
||||
Argument 'y' compiled with int32[] and called with float32[]
|
||||
|
||||
```
|
||||
|
||||
Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time
|
||||
@ -127,15 +132,16 @@ to invoke the resulting compiled function. Continuing with our example above:
|
||||
|
||||
>>> # Lowered HLO, specialized to the *value* of the first argument (7)
|
||||
>>> print(lowered_with_x.as_text())
|
||||
module @jit_f.1 {
|
||||
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = stablehlo.constant dense<14> : tensor<i32>
|
||||
%1 = stablehlo.add %0, %arg0 : tensor<i32>
|
||||
return %1 : tensor<i32>
|
||||
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
|
||||
%c = stablehlo.constant dense<14> : tensor<i32>
|
||||
%0 = stablehlo.add %c, %arg0 : tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
}
|
||||
>>> lowered_with_x.compile()(5)
|
||||
DeviceArray(19, dtype=int32)
|
||||
Array(19, dtype=int32, weak_type=True)
|
||||
|
||||
```
|
||||
|
||||
Note that `lower` here takes two arguments as usual, but the subsequent compiled
|
||||
@ -149,11 +155,13 @@ shape/dtype structure, it is necessary that the static first argument be a
|
||||
concrete value. Otherwise, lowering would err:
|
||||
|
||||
```python
|
||||
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)
|
||||
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP
|
||||
Traceback (most recent call last):
|
||||
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'
|
||||
|
||||
>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
|
||||
DeviceArray(25, dtype=int32)
|
||||
Array(25, dtype=int32)
|
||||
|
||||
```
|
||||
|
||||
## AOT-compiled functions cannot be transformed
|
||||
@ -179,13 +187,15 @@ in transformations. Example:
|
||||
>>> g_aot = jax.jit(g).lower(z).compile()
|
||||
|
||||
>>> jax.vmap(g_jit)(zs)
|
||||
DeviceArray([[ 1., 5., 9.],
|
||||
Array([[ 1., 5., 9.],
|
||||
[13., 17., 21.],
|
||||
[25., 29., 33.],
|
||||
[37., 41., 45.]], dtype=float32)
|
||||
|
||||
>>> jax.vmap(g_aot)(zs)
|
||||
>>> jax.vmap(g_aot)(zs) # doctest: +SKIP
|
||||
Traceback (most recent call last):
|
||||
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.
|
||||
|
||||
```
|
||||
|
||||
A similar error is raised when `g_aot` is involved in autodiff
|
||||
|
@ -21,6 +21,7 @@ Array([0, 0], dtype=uint32)
|
||||
(2,)
|
||||
>>> key.dtype
|
||||
dtype('uint32')
|
||||
|
||||
```
|
||||
Starting now, new-style RNG keys can be created with
|
||||
{func}`jax.random.key`:
|
||||
@ -33,6 +34,7 @@ Array((), dtype=key<fry>) overlaying:
|
||||
()
|
||||
>>> key.dtype
|
||||
key<fry>
|
||||
|
||||
```
|
||||
This (scalar-shaped) array behaves the same as any other JAX array, except
|
||||
that its element type is a key (and associated metadata). We can make
|
||||
@ -48,6 +50,7 @@ Array((4,), dtype=key<fry>) overlaying:
|
||||
[0 3]]
|
||||
>>> key_arr.shape
|
||||
(4,)
|
||||
|
||||
```
|
||||
Aside from switching to a new constructor, most PRNG-related code should
|
||||
continue to work as expected. You can continue to use keys in
|
||||
@ -62,14 +65,17 @@ data = jax.random.uniform(key, shape=(5,))
|
||||
However, not all numerical operations work on key arrays. They now
|
||||
intentionally raise errors:
|
||||
```python
|
||||
>>> key = key + 1
|
||||
ValueError: dtype=key<fry> is not a valid dtype for JAX type promotion.
|
||||
>>> key = key + 1 # doctest: +SKIP
|
||||
Traceback (most recent call last):
|
||||
TypeError: add does not accept dtypes key<fry>, int32.
|
||||
|
||||
```
|
||||
If for some reason you need to recover the underlying buffer
|
||||
(the old-style key), you can do so with {func}`jax.random.key_data`:
|
||||
```python
|
||||
>>> jax.random.key_data(key)
|
||||
Array([0, 0], dtype=uint32)
|
||||
|
||||
```
|
||||
For old-style keys, {func}`~jax.random.key_data` is an identity operation.
|
||||
|
||||
@ -108,6 +114,7 @@ True
|
||||
>>> raw_key = jax.random.PRNGKey(0)
|
||||
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
|
||||
False
|
||||
|
||||
```
|
||||
|
||||
### Type annotations for PRNG Keys
|
||||
@ -173,6 +180,7 @@ Array((), dtype=key<rbg>) overlaying:
|
||||
[0 0 0 0]
|
||||
>>> jax.random.uniform(key, shape=(3,))
|
||||
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
|
||||
|
||||
```
|
||||
|
||||
### Safe PRNG key use
|
||||
@ -322,6 +330,7 @@ which has the following property:
|
||||
```python
|
||||
>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
|
||||
True
|
||||
|
||||
```
|
||||
PRNG key arrays then have a dtype with the following properties:
|
||||
```python
|
||||
@ -330,6 +339,7 @@ PRNG key arrays then have a dtype with the following properties:
|
||||
True
|
||||
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
|
||||
True
|
||||
|
||||
```
|
||||
And in addition to `key.dtype._rules` as outlined for extended dtypes in
|
||||
general, PRNG dtypes define `key.dtype._impl`, which contains the metadata
|
||||
|
@ -2223,6 +2223,7 @@
|
||||
"\n",
|
||||
" >>> jnp.arange(254.0, 258.0).astype('uint8')\n",
|
||||
" Array([254, 255, 255, 255], dtype=uint8)\n",
|
||||
"\n",
|
||||
" ```\n",
|
||||
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
|
||||
"\n",
|
||||
|
@ -1143,6 +1143,7 @@ Many such cases are discussed in detail in the sections above; here we list seve
|
||||
|
||||
>>> jnp.arange(254.0, 258.0).astype('uint8')
|
||||
Array([254, 255, 255, 255], dtype=uint8)
|
||||
|
||||
```
|
||||
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
|
||||
|
||||
|
@ -226,7 +226,7 @@ context manager:
|
||||
>>> x = jnp.float32(1)
|
||||
>>> y = jnp.int32(1)
|
||||
>>> with jax.numpy_dtype_promotion('strict'):
|
||||
... z = x + y # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
... z = x + y # doctest: +SKIP
|
||||
...
|
||||
Traceback (most recent call last):
|
||||
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
|
||||
|
Loading…
x
Reference in New Issue
Block a user