diff --git a/docs/export/export.md b/docs/export/export.md index b62cf9fe0..aa686b03e 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -247,7 +247,7 @@ for which the code was exported. You can specify explicitly for what platforms the code should be exported. This allows you to specify a different accelerator than you have available at export time, -and it even allows you to specify multi-platform lexport to +and it even allows you to specify multi-platform export to obtain an `Exported` object that can be compiled and executed on multiple platforms. @@ -293,7 +293,7 @@ resulting module size should be only marginally larger than the size of a module with default export. As an extreme case, when serializing a module without any primitives with platform-specific lowering, you will get -the same StableHLO as for the single-plaform export. +the same StableHLO as for the single-platform export. ```python >>> import jax diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index b1ce80638..6ad7fb5c2 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -44,7 +44,7 @@ following example: ``` Note that such functions are still re-compiled on demand for -each concrete input shapes they are invoked on. Only the +each concrete input shape they are invoked on. Only the tracing and the lowering are saved. The {func}`jax.export.symbolic_shape` is used in the above @@ -98,7 +98,7 @@ A few examples of shape specifications: arguments. Note that the same specification would work if the first argument is a pytree of 3D arrays, all with the same leading dimension but possibly with different trailing dimensions. - The value `None` for the second arugment means that the argument + The value `None` for the second argument means that the argument is not symbolic. Equivalently, one can use `...`. * `("(batch, ...)", "(batch,)")` specifies that the two arguments @@ -256,7 +256,7 @@ as follows: integers. E.g., `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`, `a >= b`, `a - b >= 0` are inconclusive and result in an exception. -In cases where a comparison operation cannot be resolve to a boolean, +In cases where a comparison operation cannot be resolved to a boolean, we raise {class}`InconclusiveDimensionOperation`. E.g., ```python @@ -351,7 +351,7 @@ symbolic constraints: is encountered, it is rewritten to the expression on the right. E.g., `floordiv(a, b) == c` works by replacing all - occurences of `floordiv(a, b)` with `c`. + occurrences of `floordiv(a, b)` with `c`. Equality constraints must not contain addition or subtraction at the top-level on the left-hand-side. Examples of valid left-hand-sides are `a * b`, or `4 * a`, or @@ -498,11 +498,11 @@ This works well for most use cases, and it mirrors the calling convention of JIT functions. Sometimes you may want to export a function parameterized -by an integer values that determines some shapes in the program. +by an integer value that determines some shapes in the program. For example, we may want to export the function `my_top_k` defined below, parameterized by the -value of `k`, which determined the shape of the result. +value of `k`, which determines the shape of the result. 
The following attempt will lead to an error since the dimension variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`: diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index caf63df17..da33a677b 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -237,7 +237,7 @@ params_vars = tf.nest.map_structure(tf.Variable, params) prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs) my_model = tf.Module() -# Tell the model saver what are the variables. +# Tell the model saver what the variables are. my_model._variables = tf.nest.flatten(params_vars) my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False) tf.saved_model.save(my_model) @@ -760,7 +760,7 @@ symbolic constraints: We plan to improve somewhat this area in the future. * Equality constraints are treated as normalization rules. E.g., `floordiv(a, b) = c` works by replacing all - occurences of the left-hand-side with the right-hand-side. + occurrences of the left-hand-side with the right-hand-side. You can only have equality constraints where the left-hand-side is a multiplication of factors, e.g, `a * b`, or `4 * a`, or `floordiv(a, b)`. Thus, the left-hand-side cannot contain @@ -1048,7 +1048,7 @@ jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32 tf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14, dtype=tf.float64)) ``` -When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types +When the `JAX_ENABLE_X64` flag is set, JAX uses 64-bit types for Python scalars and respects the explicit 64-bit types: ```python @@ -1245,7 +1245,7 @@ Applies to both native and non-native serialization. trackable classes during attribute assignment. Python Dict/List/Tuple are changed to _DictWrapper/_ListWrapper/_TupleWrapper classes. -In most situation, these Wrapper classes work exactly as the standard +In most situations, these Wrapper classes work exactly as the standard Python data types. However, the low-level pytree data structures are different and this can lead to errors. @@ -1499,7 +1499,7 @@ during lowering we try to generate one TensorFlow op for one JAX primitive. We expect that the lowering that XLA does is similar to that done by JAX before conversion. (This is a hypothesis, we have not yet verified it extensively.) -There is one know case when the performance of the lowered code will be different. +There is one known case when the performance of the lowered code will be different. JAX programs use a [stateless deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md) and it has an internal JAX primitive for it.
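
The hunks above touch the docs for multi-platform export and shape-polymorphic export in `jax.export`. As a quick, hedged illustration of the behavior those docs describe (not part of the patch itself), the sketch below exports a function for several platforms and with a symbolic batch dimension. The function `f`, the platform tuple, and the shape strings are illustrative assumptions; it presumes a recent JAX with the `jax.export` module.

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
  return jnp.sin(x)

# Multi-platform export: lower once for several platforms, even ones not
# available at export time (platform list here is an illustrative choice).
x_spec = jax.ShapeDtypeStruct((4,), jnp.float32)
exp = export.export(jax.jit(f), platforms=("cpu", "tpu"))(x_spec)

# The serialized artifact can be stored and rehydrated on another host,
# then called on any of the platforms it was exported for.
blob = exp.serialize()
rehydrated = export.deserialize(blob)
print(rehydrated.call(jnp.arange(4, dtype=jnp.float32)))

# Shape-polymorphic export: "b" is a dimension variable, so tracing and
# lowering happen once for any leading batch size.
poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.float32)
exp_poly = export.export(jax.jit(f))(poly_spec)
print(exp_poly.in_avals)
```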