From 0223e830c1032a0f18274d880a7a0ef734bd91a6 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 3 Jun 2024 11:27:34 +0300 Subject: [PATCH] [docs] Enable doctest for .md documentation files The current invocation of doctest from GH actions picked up only .rst files. We enable .md files also, and we make a few changes to ensure that the doctest passes on the existing files. The changes fall into several categories: * add a newline before the end of the code block, for doctest to pick up the expected output properly * update the expected values to match the current behavior * disable some doctests that raise expected exceptions, whenever I could not get doctest to match the exception details. Sometimes +IGNORE_EXCEPTION_DETAIL was enough, and other times I had to use +SKIP. --- .github/workflows/ci-build.yaml | 2 +- docs/aot.md | 56 +++++++++++++--------- docs/jep/9263-typed-keys.md | 14 +++++- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 1 + docs/notebooks/Common_Gotchas_in_JAX.md | 1 + docs/type_promotion.rst | 2 +- 6 files changed, 49 insertions(+), 27 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 3ece6976c..566ecaca7 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -139,7 +139,7 @@ jobs: JAX_ARRAY: 1 PY_COLORS: 1 run: | - pytest -n auto --tb=short docs + pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py diff --git a/docs/aot.md b/docs/aot.md index 3304f4081..e1420f702 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -45,12 +45,12 @@ way. An example: >>> # Print lowered HLO >>> print(lowered.as_text()) -module @jit_f.0 { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = stablehlo.constant dense<2> : tensor - %1 = stablehlo.multiply %0, %arg0 : tensor - %2 = stablehlo.add %1, %arg1 : tensor - return %2 : tensor +module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}, %arg1: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor + %0 = stablehlo.multiply %c, %arg0 : tensor + %1 = stablehlo.add %0, %arg1 : tensor + return %1 : tensor } } @@ -62,7 +62,8 @@ module @jit_f.0 { >>> # Execute the compiled function! >>> compiled(x, y) -DeviceArray(10, dtype=int32) +Array(10, dtype=int32, weak_type=True) + ``` See the {mod}`jax.stages` documentation for more details on what functionality @@ -83,7 +84,8 @@ that have `shape` and `dtype` attributes: ```python >>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32')) >>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y) -DeviceArray(10, dtype=int32) +Array(10, dtype=int32) + ``` More generally, `lower` only needs its arguments to structurally supply what JAX @@ -97,18 +99,21 @@ lowering raises an error: ```python >>> x_1d = y_1d = jnp.arange(3) ->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) +>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL ... +Traceback (most recent call last): TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: Argument 'x' compiled with int32[] and called with int32[3] Argument 'y' compiled with int32[] and called with int32[3] >>> x_f = y_f = jnp.float32(72.) ->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) +>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL ... +Traceback (most recent call last): TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: Argument 'x' compiled with int32[] and called with float32[] Argument 'y' compiled with int32[] and called with float32[] + ``` Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time @@ -127,15 +132,16 @@ to invoke the resulting compiled function. Continuing with our example above: >>> # Lowered HLO, specialized to the *value* of the first argument (7) >>> print(lowered_with_x.as_text()) -module @jit_f.1 { - func.func public @main(%arg0: tensor) -> tensor { - %0 = stablehlo.constant dense<14> : tensor - %1 = stablehlo.add %0, %arg0 : tensor - return %1 : tensor +module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<14> : tensor + %0 = stablehlo.add %c, %arg0 : tensor + return %0 : tensor } } >>> lowered_with_x.compile()(5) -DeviceArray(19, dtype=int32) +Array(19, dtype=int32, weak_type=True) + ``` Note that `lower` here takes two arguments as usual, but the subsequent compiled @@ -149,11 +155,13 @@ shape/dtype structure, it is necessary that the static first argument be a concrete value. Otherwise, lowering would err: ```python ->>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) +>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP +Traceback (most recent call last): TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct' >>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5) -DeviceArray(25, dtype=int32) +Array(25, dtype=int32) + ``` ## AOT-compiled functions cannot be transformed @@ -179,13 +187,15 @@ in transformations. Example: >>> g_aot = jax.jit(g).lower(z).compile() >>> jax.vmap(g_jit)(zs) -DeviceArray([[ 1., 5., 9.], - [13., 17., 21.], - [25., 29., 33.], - [37., 41., 45.]], dtype=float32) +Array([[ 1., 5., 9.], + [13., 17., 21.], + [25., 29., 33.], + [37., 41., 45.]], dtype=float32) ->>> jax.vmap(g_aot)(zs) +>>> jax.vmap(g_aot)(zs) # doctest: +SKIP +Traceback (most recent call last): TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type . + ``` A similar error is raised when `g_aot` is involved in autodiff diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md index 925fc2c47..828b95e8c 100644 --- a/docs/jep/9263-typed-keys.md +++ b/docs/jep/9263-typed-keys.md @@ -21,6 +21,7 @@ Array([0, 0], dtype=uint32) (2,) >>> key.dtype dtype('uint32') + ``` Starting now, new-style RNG keys can be created with {func}`jax.random.key`: @@ -33,6 +34,7 @@ Array((), dtype=key) overlaying: () >>> key.dtype key + ``` This (scalar-shaped) array behaves the same as any other JAX array, except that its element type is a key (and associated metadata). We can make @@ -48,6 +50,7 @@ Array((4,), dtype=key) overlaying: [0 3]] >>> key_arr.shape (4,) + ``` Aside from switching to a new constructor, most PRNG-related code should continue to work as expected. You can continue to use keys in @@ -62,14 +65,17 @@ data = jax.random.uniform(key, shape=(5,)) However, not all numerical operations work on key arrays. They now intentionally raise errors: ```python ->>> key = key + 1 -ValueError: dtype=key is not a valid dtype for JAX type promotion. +>>> key = key + 1 # doctest: +SKIP +Traceback (most recent call last): +TypeError: add does not accept dtypes key, int32. + ``` If for some reason you need to recover the underlying buffer (the old-style key), you can do so with {func}`jax.random.key_data`: ```python >>> jax.random.key_data(key) Array([0, 0], dtype=uint32) + ``` For old-style keys, {func}`~jax.random.key_data` is an identity operation. @@ -108,6 +114,7 @@ True >>> raw_key = jax.random.PRNGKey(0) >>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key) False + ``` ### Type annotations for PRNG Keys @@ -173,6 +180,7 @@ Array((), dtype=key) overlaying: [0 0 0 0] >>> jax.random.uniform(key, shape=(3,)) Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32) + ``` ### Safe PRNG key use @@ -322,6 +330,7 @@ which has the following property: ```python >>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended) True + ``` PRNG key arrays then have a dtype with the following properties: ```python @@ -330,6 +339,7 @@ PRNG key arrays then have a dtype with the following properties: True >>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key) True + ``` And in addition to `key.dtype._rules` as outlined for extended dtypes in general, PRNG dtypes define `key.dtype._impl`, which contains the metadata diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index d8dffdb8a..b52dc2176 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -2223,6 +2223,7 @@ "\n", " >>> jnp.arange(254.0, 258.0).astype('uint8')\n", " Array([254, 255, 255, 255], dtype=uint8)\n", + "\n", " ```\n", " This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index e63d64d94..a46a07b5f 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -1143,6 +1143,7 @@ Many such cases are discussed in detail in the sections above; here we list seve >>> jnp.arange(254.0, 258.0).astype('uint8') Array([254, 255, 255, 255], dtype=uint8) + ``` This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa. diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst index c70371478..103a8331d 100644 --- a/docs/type_promotion.rst +++ b/docs/type_promotion.rst @@ -226,7 +226,7 @@ context manager: >>> x = jnp.float32(1) >>> y = jnp.int32(1) >>> with jax.numpy_dtype_promotion('strict'): - ... z = x + y # doctest: +IGNORE_EXCEPTION_DETAIL + ... z = x + y # doctest: +SKIP ... Traceback (most recent call last): TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit