diff --git a/docs/faq.rst b/docs/faq.rst index 941c6c147..0f8709b4f 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -304,7 +304,7 @@ treated as dynamic. Here's how it might look:: ... CustomClass._tree_unflatten) This is certainly more involved, but it solves all the issues associated with the simpler -apporaches used above:: +approaches used above:: >>> c = CustomClass(2, True) >>> print(c.calc(3)) @@ -442,7 +442,7 @@ When run with a GPU in Colab_, we see: - JAX takes 193 ms to compile the function - JAX takes 485 µs per evaluation on the GPU -In this case, we see that once the data is transfered and the function is +In this case, we see that once the data is transferred and the function is compiled, JAX on the GPU is about 30x faster for repeated evaluations. Is this a fair comparison? Maybe. The performance that ultimately matters is for diff --git a/docs/jax-101/01-jax-basics.ipynb b/docs/jax-101/01-jax-basics.ipynb index 20aba4440..e4521c70a 100644 --- a/docs/jax-101/01-jax-basics.ipynb +++ b/docs/jax-101/01-jax-basics.ipynb @@ -620,7 +620,7 @@ "\n", "Of course, it's possible to mix side-effectful Python code and functionally pure JAX code, and we will touch on this more later. As you get more familiar with JAX, you will learn how and when this can work. As a rule of thumb, however, any functions intended to be transformed by JAX should avoid side-effects, and the JAX primitives themselves will try to help you do that.\n", "\n", - "We will explain other places where the JAX idiosyncracies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs." + "We will explain other places where the JAX idiosyncrasies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs." ] }, { diff --git a/docs/jax-101/01-jax-basics.md b/docs/jax-101/01-jax-basics.md index 4431c085a..f2ac47cdc 100644 --- a/docs/jax-101/01-jax-basics.md +++ b/docs/jax-101/01-jax-basics.md @@ -292,7 +292,7 @@ Isn't the pure version less efficient? Strictly, yes; we are creating a new arra Of course, it's possible to mix side-effectful Python code and functionally pure JAX code, and we will touch on this more later. As you get more familiar with JAX, you will learn how and when this can work. As a rule of thumb, however, any functions intended to be transformed by JAX should avoid side-effects, and the JAX primitives themselves will try to help you do that. -We will explain other places where the JAX idiosyncracies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs. +We will explain other places where the JAX idiosyncrasies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs. +++ {"id": "dFn_VBFFlGCz"} diff --git a/docs/jax-101/05.1-pytrees.ipynb b/docs/jax-101/05.1-pytrees.ipynb index 0800f646c..d80ff386a 100644 --- a/docs/jax-101/05.1-pytrees.ipynb +++ b/docs/jax-101/05.1-pytrees.ipynb @@ -392,7 +392,7 @@ "\n", "* `SequenceKey(idx: int)`: for lists and tuples.\n", "* `DictKey(key: Hashable)`: for dictionaries.\n", - "* `GetAttrKey(name: str)`: for `namedtuple`s and preferrably custom pytree nodes (more in the next section)\n", + "* `GetAttrKey(name: str)`: for `namedtuple`s and preferably custom pytree nodes (more in the next section)\n", "\n", "You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression." ] diff --git a/docs/jax-101/05.1-pytrees.md b/docs/jax-101/05.1-pytrees.md index b01d937ad..b806eacd8 100644 --- a/docs/jax-101/05.1-pytrees.md +++ b/docs/jax-101/05.1-pytrees.md @@ -233,7 +233,7 @@ To express key paths, JAX provides a few default key types for the built-in pytr * `SequenceKey(idx: int)`: for lists and tuples. * `DictKey(key: Hashable)`: for dictionaries. -* `GetAttrKey(name: str)`: for `namedtuple`s and preferrably custom pytree nodes (more in the next section) +* `GetAttrKey(name: str)`: for `namedtuple`s and preferably custom pytree nodes (more in the next section) You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression. diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 874bdb796..d1a889418 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -166,7 +166,7 @@ Inputs to JAX functions and methods should be typed as permissively as is reason - `ArrayLike` would be a union of anything that can be implicitly converted into an array: for example, jax arrays, numpy arrays, JAX tracers, and python or numpy scalars - `DTypeLike` would be a union of anything that can be implicitly converted into a dtype: for example, numpy dtypes, numpy dtype objects, jax dtype objects, strings, and built-in types. -- `ShapeLike` would be a union of anything that could be converted into a shape: for example, sequences of integer or integer-like objecs. +- `ShapeLike` would be a union of anything that could be converted into a shape: for example, sequences of integer or integer-like objects. - etc. Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DTypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in place of arrays, so the type definition will be simpler than the NumPy analog. @@ -177,7 +177,7 @@ Conversely, outputs of functions and methods should be typed as strictly as poss - `Array` or `NDArray` (see below) for type annotation purposes is effectively equivalent to `Union[Tracer, jnp.ndarray]` and should be used to annotate array outputs. - `DType` is an alias of `np.dtype`, perhaps with the ability to also represent key types and other generalizations used within JAX. -- `Shape` is essentially `Tuple[int, ...]`, perhaps with some additional flexibilty to account for dynamic shapes. +- `Shape` is essentially `Tuple[int, ...]`, perhaps with some additional flexibility to account for dynamic shapes. - `NamedShape` is an extension of `Shape` that allows for named shapes as used internall in JAX. - etc. @@ -283,7 +283,7 @@ Finally, we could opt for full unification via restructuring of the class hierar Here `jnp.ndarray` could be an alias for `jax.Array`. This final approach is in some senses the most pure, but it is somewhat forced from an OOP design standpoint (`Tracer` *is an* `Array`?). -##### Option 4: Parial unification via class hierarchy +##### Option 4: Partial unification via class hierarchy We could make the class hierarchy more sensible by making `Tracer` and the class for on-device arrays inherit from a common base class. So, for example: diff --git a/docs/jep/17111-shmap-transpose.md b/docs/jep/17111-shmap-transpose.md index 47da5ce87..2fdf5f822 100644 --- a/docs/jep/17111-shmap-transpose.md +++ b/docs/jep/17111-shmap-transpose.md @@ -505,7 +505,7 @@ than an `int`). Here are reasons we like unmapped inputs and outputs for `shmap`: * **Same expressiveness as `pjit`.** Anything `pjit` can do, the `shmap` escape hatch should be able to do too. Or else we'd have a lacking escape hatch! If - we didnt have unmapped outputs in `shmap` then we couldn't express the same + we didn't have unmapped outputs in `shmap` then we couldn't express the same batch-parallel loss function computations as `pjit`. * **Closed-over inputs.** Closed-over inputs essentially correspond to unmapped inputs, and... diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index 51873d209..ebaec8235 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -946,7 +946,7 @@ "source": [ "### How to handle `uint64`?\n", "\n", - "The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer opertion involving `uint64` should result in `int128`, but this is not a standard available dtype.\n", + "The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer operation involving `uint64` should result in `int128`, but this is not a standard available dtype.\n", "\n", "Numpy's choice here is to promote to `float64`:" ] @@ -1919,7 +1919,7 @@ "source": [ "### Tensorflow Type Promotion\n", "\n", - "Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercable to the type of `x`." + "Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercible to the type of `x`." ] }, { diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index c8c18d887..b6a5d4fc2 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -457,7 +457,7 @@ Again, the connections added here are precisely the promotion semantics implemen ### How to handle `uint64`? -The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer opertion involving `uint64` should result in `int128`, but this is not a standard available dtype. +The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer operation involving `uint64` should result in `int128`, but this is not a standard available dtype. Numpy's choice here is to promote to `float64`: @@ -796,7 +796,7 @@ display.HTML(table.to_html()) ### Tensorflow Type Promotion -Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercable to the type of `x`. +Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercible to the type of `x`. ```{code-cell} :cellView: form diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 610ceb659..9aec8b1a2 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -203,7 +203,7 @@ "id": "40oy-FbmVkDc" }, "source": [ - "When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility definied in this notebook:" + "When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility defined in this notebook:" ] }, { @@ -359,7 +359,7 @@ "\n", "\n", "\n", - "In both `jax.linearize` and `jax.vjp` there is flexibilty in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.\n", + "In both `jax.linearize` and `jax.vjp` there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.\n", "\n", "One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of `sin_vjp`:" ] diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index f7ffef9ba..d8486a0b5 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -99,7 +99,7 @@ jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x) +++ {"id": "40oy-FbmVkDc"} -When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility definied in this notebook: +When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility defined in this notebook: ```{code-cell} from jax.tree_util import tree_flatten, tree_unflatten @@ -162,7 +162,7 @@ You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.read -In both `jax.linearize` and `jax.vjp` there is flexibilty in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`. +In both `jax.linearize` and `jax.vjp` there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`. One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of `sin_vjp`: diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index 1ce077ca7..f8dcaa368 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -129,7 +129,7 @@ "ax[0].imshow(image, cmap='binary_r')\n", "ax[0].set_title('original')\n", "\n", - "# Create a noisy version by adding random Gausian noise\n", + "# Create a noisy version by adding random Gaussian noise\n", "key = random.PRNGKey(1701)\n", "noisy_image = image + 50 * random.normal(key, image.shape)\n", "ax[1].imshow(noisy_image, cmap='binary_r')\n", @@ -722,7 +722,7 @@ "# N,H,W,C = img.shape\n", "# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))\n", "\n", - "# transposed conv = 180deg kernel roation plus LHS dilation\n", + "# transposed conv = 180deg kernel rotation plus LHS dilation\n", "# rotate kernel 180deg:\n", "kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))\n", "# need a custom output padding:\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 15940a4cb..ce5b5a4aa 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -83,7 +83,7 @@ image = jnp.array(misc.face().mean(-1)) ax[0].imshow(image, cmap='binary_r') ax[0].set_title('original') -# Create a noisy version by adding random Gausian noise +# Create a noisy version by adding random Gaussian noise key = random.PRNGKey(1701) noisy_image = image + 50 * random.normal(key, image.shape) ax[1].imshow(noisy_image, cmap='binary_r') @@ -330,7 +330,7 @@ We can use the last to, for instance, implement _transposed convolutions_: # N,H,W,C = img.shape # out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1)) -# transposed conv = 180deg kernel roation plus LHS dilation +# transposed conv = 180deg kernel rotation plus LHS dilation # rotate kernel 180deg: kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1)) # need a custom output padding: diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index e55d35bcf..5cda80620 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -333,7 +333,7 @@ "id": "LrvdAloMZbIe" }, "source": [ - "By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may elliminate the callback entirely:" + "By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:" ] }, { @@ -398,7 +398,7 @@ "\n", "In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.\n", "\n", - "As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generaing a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)." + "As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)." ] }, { diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index 9364e001d..e87c360f5 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -171,7 +171,7 @@ For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pu +++ {"id": "LrvdAloMZbIe"} -By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may elliminate the callback entirely: +By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely: ```{code-cell} :id: mmFc_zawZrBq @@ -208,7 +208,7 @@ In `f2` on the other hand, the output of the callback is unused, and so the comp In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects. -As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generaing a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!). +As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!). ```{code-cell} :id: eAg5xIhrOiWV diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index c77f195f7..8558d7038 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -173,7 +173,7 @@ "A 2D grid\n", "\n", "\n", - "When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invokations is referred to as a \"program\", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invokation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.\n", + "When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invocations is referred to as a \"program\", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.\n", "\n", "Here's an example kernel that uses a `grid` and `program_id`." ] diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 54438bac7..25b1a5899 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -112,7 +112,7 @@ We run the kernel function once for each element, a style of single-program mult A 2D grid -When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invokations is referred to as a "program", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invokation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. +When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invocations is referred to as a "program", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. Here's an example kernel that uses a `grid` and `program_id`. diff --git a/docs/pallas/tpu.rst b/docs/pallas/tpu.rst index a8270eb62..cefd3f657 100644 --- a/docs/pallas/tpu.rst +++ b/docs/pallas/tpu.rst @@ -86,7 +86,7 @@ with compute. What's more, compared to GPUs, TPUs are actually highly sequential machines. That's why, the grid is generally not processed in parallel, but sequentially, in lexicographic order (though see the `Multicore TPU configurations`_ section -for exceptions). This unlocks some interesting capabilites: +for exceptions). This unlocks some interesting capabilities: * When two (lexicographically) consecutive grid indices use the same slice of an input, the HBM transfer in the second iteration is skipped, as the data is diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 983075c4c..804019aa3 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -197,7 +197,7 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj): # be part of the cache key as their inclusion will result in unnecessary cache # misses. Clear them here by setting bool values to False, ints to 0, and # strings to empty. The exact values used to clear are not relevant as long - # as the same values are used everytime for each field. + # as the same values are used every time for each field. debug_options = compile_options_copy.executable_build_options.debug_options # LINT.IfChange(debug_options) debug_options.xla_force_host_platform_device_count = 0 diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 50c72626d..5a2f0e774 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -82,7 +82,7 @@ class ClusterEnv: """Returns address and port used by JAX to bootstrap. Process id 0 will open a tcp socket at "hostname:port" where - all the proccesses will connect to initialize the distributed JAX service. + all the processes will connect to initialize the distributed JAX service. The selected port needs to be free. :func:`get_coordinator_address` needs to return the same hostname and port on all the processes. diff --git a/jax/_src/config.py b/jax/_src/config.py index 3fbce2ce3..1c9eb5be2 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -801,7 +801,7 @@ spmd_mode = config.define_enum_state( " to execute on non-fully addressable `jax.Array`s\n" "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " " `jax.jit` and all other operations are allowed to " - " execute on non-fully addresable `jax.Array`s.")) + " execute on non-fully addressable `jax.Array`s.")) distributed_debug = config.define_bool_state( diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 8d89281cf..d4b2cf8a0 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1198,7 +1198,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args, Args: fun: a Python callable specifying a linear function. It should take two arguments: one of "residual" inputs (type ``r``), - i.e. inputs in which the function is not necessarly linear, and + i.e. inputs in which the function is not necessarily linear, and one of "linear" inputs (type ``a``). It should return output whose components are linear in the linear input (type ``b``). fun_transpose: a Python callable specifying a structurally linear diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 7885f4ebc..d13a85880 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -165,7 +165,7 @@ def initialize(coordinator_address: Optional[str] = None, Example: - Suppose there are two GPU processs, and process 0 is the designated coordinator + Suppose there are two GPU processes, and process 0 is the designated coordinator with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the following commands before anything else. diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 96641719f..5594b261a 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -535,7 +535,7 @@ class UnexpectedTracerError(JAXTypeError): JAX detects leaks when you then use the leaked value in another operation later on, at which point it raises an ``UnexpectedTracerError``. To fix this, avoid side effects: if a function computes a value needed - in an outer scope, return that value from the transformed function explictly. + in an outer scope, return that value from the transformed function explicitly. Specifically, a ``Tracer`` is JAX's internal representation of a function's intermediate values during transformations, e.g. within :func:`~jax.jit`, diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 1bb62194a..84e616a62 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -406,7 +406,7 @@ class BatchTrace(Trace): else: axis_size = None # can't be inferred from data if self.axis_name is core.no_axis_name: - assert axis_size is not None # must be inferrable from data + assert axis_size is not None # must be inferable from data return core.AxisEnvFrame(self.axis_name, axis_size, self.main) frame = core.axis_frame(self.axis_name, self.main) assert axis_size is None or axis_size == frame.size, (axis_size, frame.size) @@ -1063,7 +1063,7 @@ def _mask_one_ragged_axis( ragged_axis, segment_lengths = axis_spec.ragged_axes[0] value = ident(operand.dtype) positions = jax.lax.broadcasted_iota('int32', operand.shape, ragged_axis) - # TODO(mattjj, axch) cant get ._data, need to convert it + # TODO(mattjj, axch) can't get ._data, need to convert it # lengths = jax.lax.convert_element_type(segment_lengths._data, 'int32') lengths = jax.lax.convert_element_type(segment_lengths, 'int32') limits = jax.lax.broadcast_in_dim( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index a3cf0c4e9..67db40e6d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1058,7 +1058,7 @@ def partial_eval_jaxpr_nounits( In the above example, the output of jaxpr_known named `d` is a _residual_ output, and corresponds to the input named `a` in jaxpr_unknown. In general, jaxpr_known will produce extra outputs (at the end of its output list) - corresponding to intermeidate values of the original jaxpr which must be + corresponding to intermediate values of the original jaxpr which must be passed to jaxpr_unknown (as leading inputs). """ instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate @@ -2411,7 +2411,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]: invars = [*jaxpr.constvars, *jaxpr.invars] expl_outvars = jaxpr.outvars - # First do a pass to collect implicit outputs, meaning variables which occurr + # First do a pass to collect implicit outputs, meaning variables which occur # in explicit_outvars types but not in invars or to the left in outvars. seen: set[Var] = set(invars) impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index da89294d4..aa74a151e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2856,7 +2856,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): in_shardings = [gs] * num_in_shardings out_shardings = [gs] * num_out_shardings # jit(pmap) will generate Arrays with multi-device sharding. - # It is unsupported for these shardings to be uncommited, so force + # It is unsupported for these shardings to be uncommitted, so force # the outputs to be committed. committed = True return in_shardings, out_shardings, committed, tuple(local_devices) diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index b00fb0fb3..ac1b5aec1 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -203,7 +203,7 @@ def approx_min_k(operand: Array, In the example above, we compute ``db^2/2 - dot(qy, db^T)`` instead of ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less - arithmetics and produces the same set of neighbors. + arithmetic and produces the same set of neighbors. """ return approx_top_k_p.bind( operand, diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 9a67b4a0b..4a93b8b68 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -719,7 +719,7 @@ def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr: def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear, unroll): # if any in_ct is nonzero, we definitely want it in args_ (and the - # corresponding x in args could be an undefined primal, but doesnt have to be) + # corresponding x in args could be an undefined primal, but doesn't have to be) # for non-res stuff: # getting and setting => (nonzero ct, UndefinedPrimal arg) # just setting => (nonzero ct, not UndefinedPrimal, dummy value) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c9ac5f34c..1396581b0 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3286,7 +3286,7 @@ mlir.register_lowering(pad_p, _pad_lower) # For N > 1, we can match up the output array axis with the second axis of the # input. But for N = 1, it is not clear how axes match up: all we know from the # JAXpr is that we are reshaping from (1, 1) to (1,). -# In constrast, squeeze[ dimensions=(0,) ] is unambiguous. +# In contrast, squeeze[ dimensions=(0,) ] is unambiguous. def _squeeze_dtype_rule(operand, *, dimensions): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 907fe6cf4..864e40464 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -2038,7 +2038,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``. Note that ``dl[0] = 0``. d: A batch of vectors with shape ``[..., m]``. - The middle diagnoal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``. + The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``. du: A batch of vectors with shape ``[..., m]``. The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``. Note that ``dl[m - 1] = 0``. diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index 8a812e553..e862009f6 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -605,7 +605,7 @@ def bessel_i0e_impl(x): elif x.dtype == np.float32: return _i0e_impl32(x) else: - # Have to upcast f16 because the magic Cephes coefficents don't have enough + # Have to upcast f16 because the magic Cephes coefficients don't have enough # precision for it. x_dtype = x.dtype x = x.astype(np.float32) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 91edd0b43..cb42186cd 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -55,7 +55,7 @@ zip, unsafe_zip = safe_zip, zip def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: """Copy the array and cast to a specified dtype. - This is implemeted via :func:`jax.lax.convert_element_type`, which may + This is implemented via :func:`jax.lax.convert_element_type`, which may have slightly different behavior than :meth:`numpy.ndarray.astype` in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. @@ -346,7 +346,7 @@ class _IndexUpdateHelper: """Helper property for index update functionality. The ``at`` property provides a functionally pure equivalent of in-place - array modificatons. + array modifications. In particular: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 06aaf05a6..59b657fec 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1840,7 +1840,7 @@ def with_sharding_constraint(x, shardings): of how to use this function, see `Distributed arrays and automatic parallelization`_. Args: - x: PyTree of jax.Arrays which will have their shardings constrainted + x: PyTree of jax.Arrays which will have their shardings constrained shardings: PyTree of sharding specifications. Valid values are the same as for the ``in_shardings`` argument of :func:`jax.experimental.pjit`. Returns: diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index f1d7d3832..f2e5aa34b 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -104,7 +104,7 @@ def start_trace(log_dir, create_perfetto_link: bool = False, Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be generated if ``create_perfetto_link`` is true. This could be useful if you want to generate a Perfetto-compatible trace without blocking the - processs. + process. """ with _profile_state.lock: if _profile_state.profile_session is not None: @@ -228,7 +228,7 @@ def trace(log_dir, create_perfetto_link=False, create_perfetto_trace=False): Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be generated if ``create_perfetto_link`` is true. This could be useful if you want to generate a Perfetto-compatible trace without blocking the - processs. + process. """ start_trace(log_dir, create_perfetto_link, create_perfetto_trace) try: diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 680a20954..b00471894 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -865,7 +865,7 @@ def _gen_derivatives(p: Array, Args: p: The 3D array containing the values of associated Legendre functions; the - dimensions are in the sequence of order (m), degree (l), and evalution + dimensions are in the sequence of order (m), degree (l), and evaluation points. x: A vector of type `float32` or `float64` containing the sampled points. is_normalized: True if the associated Legendre functions are normalized. @@ -962,7 +962,7 @@ def _gen_associated_legendre(l_max: int, harmonic of degree `l` and order `m` can be written as `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the normalization factor and θ and φ are the colatitude and longitude, - repectively. `N_l^m` is chosen in the way that the spherical harmonics form + respectively. `N_l^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis function of L^2(S^2). For the computational efficiency of spherical harmonics transform, the normalization factor is used in the computation of the ALFs. In addition, normalizing `P_l^m` @@ -999,7 +999,7 @@ def _gen_associated_legendre(l_max: int, Returns: The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values of the ALFs at `x`; the dimensions in the sequence of order, degree, and - evalution points. + evaluation points. """ p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]), dtype=x.dtype) @@ -1106,7 +1106,7 @@ def lpmn_values(m: int, n: int, z: Array, is_normalized: bool) -> Array: spherical harmonic of degree `l` and order `m` can be written as :math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`, where :math:`N_l^m` is the normalization factor and θ and φ are the - colatitude and longitude, repectively. :math:`N_l^m` is chosen in the + colatitude and longitude, respectively. :math:`N_l^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis function of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids overflow/underflow and achieves better numerical stability. @@ -1192,7 +1192,7 @@ def sph_harm(m: Array, :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`, where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!} {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and - :math:`\theta` are the colatitude and longitude, repectively. :math:`N_n^m` is + :math:`\theta` are the colatitude and longitude, respectively. :math:`N_n^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`. @@ -1600,7 +1600,7 @@ def expn_jvp(n, primals, tangents): @_wraps(osp_special.exp1, module="scipy.special") def exp1(x: ArrayLike, module='scipy.special') -> Array: x, = promote_args_inexact("exp1", x) - # Casting becuase custom_jvp generic does not work correctly with mypy. + # Casting because custom_jvp generic does not work correctly with mypy. return cast(Array, expn(1, x)) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 768b6e7f5..8a38b6221 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -477,7 +477,7 @@ class Compiled(Stage): # This is because `__call__` passes in `self._params` as the first argument. # Instead of making the call signature `call(params, *args, **kwargs)` # extract it from args because `params` can be passed as a kwarg by users - # which might confict here. + # which might conflict here. params = args[0] args = args[1:] if jax.config.jax_dynamic_shapes: diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index ae4275edb..d698398b8 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -629,7 +629,7 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool] def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]): # if any in_ct is nonzero, we definitely want it in args_ (and the - # corresponding x in args could be an undefined primal, but doesnt have to be) + # corresponding x in args could be an undefined primal, but doesn't have to be) # for non-res stuff: # getting and setting => (nonzero ct, UndefinedPrimal arg) # just setting => (nonzero ct, not UndefinedPrimal, dummy value) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index a321685aa..2c07930a6 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -216,7 +216,7 @@ def _get_pjrt_plugin_names_and_library_paths( """Gets the names and library paths of PJRT plugins to load from env var. Args: - plugins_from_env: plugin name and pathes from env var. It is in the format + plugins_from_env: plugin name and paths from env var. It is in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for windows). Returns: diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index a28411fda..81276da90 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -287,7 +287,7 @@ async def async_deserialize( # transfer instead of loading data. In the future, if memory pressure # becomes a problem, we can instead instrument bytelimiter to # keep track of all in-flight tensors and only block_until_ready, if byte - # limiter hits the limit to get reduced memory usage, without loosing + # limiter hits the limit to get reduced memory usage, without losing # performance in common use cases. await byte_limiter.release_bytes(requested_bytes) return result diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py index ef0cfa8a2..e0dcb434b 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/export.py @@ -991,7 +991,7 @@ def _call_exported_abstract_eval(*in_avals: core.AbstractValue, # We discharge all the constraints statically. This results in much simpler # composability (because we do not have to worry about the constraints of the # Exported called recursively; we only need to worry about entry-point - # constraints). This also makes sense from a composibility point of view, + # constraints). This also makes sense from a composability point of view, # because we get the same errors if we invoke the exported module, or if we # trace the exported function. Consider for example, an exported module with # signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument diff --git a/jax/experimental/export/shape_poly.py b/jax/experimental/export/shape_poly.py index a66497b74..d41c7d496 100644 --- a/jax/experimental/export/shape_poly.py +++ b/jax/experimental/export/shape_poly.py @@ -25,7 +25,7 @@ This enables many JAX programs to be traced with symbolic dimensions in some dimensions. A priority has been to enable the batch dimension in neural network examples to be polymorphic. -This was built initially for jax2tf, but it is now customizeable to be +This was built initially for jax2tf, but it is now customizable to be independent of TF. The best documentation at the moment is in the jax2tf.convert docstring, and the [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). @@ -539,7 +539,7 @@ class _DimExpr(): try: power = int(power) except: - raise InconclusiveDimensionOperation(f"Symblic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") + raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") return functools.reduce(op.mul, [self] * power) def __floordiv__(self, divisor): @@ -1269,7 +1269,7 @@ class ShapeConstraint: """Forms the error_message and error message_inputs. See shape_assertion. """ - # There is currenly a limitation in the shape assertion checker that + # There is currently a limitation in the shape assertion checker that # it supports at most 32 error_message_inputs. We try to stay within the # limit, reusing a format specifier if possible. if jaxlib_version <= (0, 4, 14): diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 4a40ad686..aa80afe9b 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -573,7 +573,7 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html" " for alternatives. Please file a feature request at " "https://github.com/google/jax/issues if none of the alternatives are " - "sufficent.") + "sufficient.") xops = xla_client._xla.ops diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index b4ddb7491..2648498ac 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -151,7 +151,7 @@ and to avoid warnings and outright errors. You can serialize JAX program into a TensorFlow SavedModel, for use with tooling that understands SavedModel. Both in native and non-native -serialization you can count on 6 months of backwards compatiblity (you +serialization you can count on 6 months of backwards compatibility (you can load a function serialized today with tooling that will be built up to 6 months in the future), and 3 weeks of limited forwards compatibility (you can load a function serialized today with tooling that was built @@ -1165,7 +1165,7 @@ There is work underway to enable more tools to consume StableHLO. Applies to native serialization only. -When you use native serialization, JAX will record the plaform for +When you use native serialization, JAX will record the platform for which the module was serialized, and you will get an error if you try to execute the `XlaCallModule` TensorFlow op on another platform. @@ -1180,7 +1180,7 @@ The current platform CPU is not among the platforms required by the module [CUDA ``` where `CPU` is the TensorFlow platform where the op is being executed -and `CUDA` is the plaform for which the module was serialized by JAX. +and `CUDA` is the platform for which the module was serialized by JAX. This probably means that JAX and TensorFlow may see different devices as the default device (JAX defaults to GPU and TensorFlow to CPU in the example error above). diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 63517a5ca..53823309d 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -83,7 +83,7 @@ def _transpose_with_shape(x: TfVal, x_shape: core.Shape, permutation) -> tuple[T def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape, rhs, rhs_shape: core.Shape, dimension_numbers): - """Tranposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions. + """Transposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions. The shapes passed in and returned may contain polynomials, and thus may be different than lhs.shape and rhs.shape. @@ -91,7 +91,7 @@ def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape, # TODO(marcvanzee): Add tests for this ops for shape polymorphism. lhs_perm, rhs_perm, _ = dimension_numbers - # TODO(marcvanzee): Consider merging tranposes if we want to optimize. + # TODO(marcvanzee): Consider merging transposes if we want to optimize. # For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW". lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, lhs_perm) # lhs --> "NCHW" if len(lhs_perm) == 3: @@ -234,7 +234,7 @@ def _validate_conv_features( elif [is_depthwise, is_atrous, is_transpose].count(True) > 1: raise _conv_error( f"Can only do one of depthwise ({is_depthwise}), atrous ({is_atrous}) " - f"and tranposed convolutions ({is_transpose})") + f"and transposed convolutions ({is_transpose})") # We can implement batch grouping when there is a need for it. if batch_group_count != 1: @@ -290,7 +290,7 @@ def _conv_general_dilated( else: padding_type = pads_to_padtype( lhs_shape[1:3], rhs_dilated_shape, window_strides, padding) - # We only manually pad if we aren't using a tranposed convolutions. + # We only manually pad if we aren't using a transposed convolutions. if padding_type == "EXPLICIT": lhs, lhs_shape, padding = _check_pad_spatial_dims(lhs, lhs_shape, padding) padding_type = padding @@ -716,7 +716,7 @@ def _reduce_window(*args, jaxpr, consts, window_dimensions, }[computation_name] result = reduce_fn(result, init_value) - # The outut is expected to be wrapped in a tuple, and since we don't use + # The output is expected to be wrapped in a tuple, and since we don't use # variadic reductions, this tuple always contains a single element. return (result,) @@ -933,7 +933,7 @@ def _pre_gather_with_batch_dims(args: GatherArgs): # also need to re-work the output reshaping raise ValueError("only len(collapsed_slice_dims) == 0 is supported") - # NOTE: This supports higher dimensions than listed (the highest dimenison + # NOTE: This supports higher dimensions than listed (the highest dimension # in the tests is 3D so it is limited to that, but the implementation is # designed to handle higher dimensions (N-Dimensional)). if len(args.batch_dims) not in [1, 2, 3]: @@ -1209,8 +1209,8 @@ def convert_scatter_jax_to_tf(update_op, unsorted_segment_op=None): Wrapper around the scatter function. The underlying tf ops `tf.tensor_scatter_nd_update` and `tf.math.unsorted_segment_*` index from the front dimensions. - `tf.math.unsorted_segment_*` indexs to a depth 1 from the front. - `tf.tensor_scatter_nd_update` indexs from the front dimensions onwards, + `tf.math.unsorted_segment_*` indexes to a depth 1 from the front. + `tf.tensor_scatter_nd_update` indexes from the front dimensions onwards, with no ability to skip a dimension. This function shifts the axes to be indexed to the front then calls a front-specific implementation, then inverse-shifts the output. diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 1beaeed4f..8a7d65004 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1723,7 +1723,7 @@ def _minmax(x: TfVal, y: TfVal, *, is_min: bool, def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal: # For reducers we will need min/max for scalars only. In that case we - # can construct the AbstractValues outselves, even in the presence of + # can construct the AbstractValues ourselves, even in the presence of # shape polymorphism. assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}" aval = core.ShapedArray((), _to_jax_dtype(x.dtype)) diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index ae3a86e27..67f3ce690 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -50,7 +50,7 @@ def serialize_directory(directory_path): def deserialize_directory(serialized_string, output_directory): - """Deserialize the string to the diretory.""" + """Deserialize the string to the directory.""" # Convert the base64-encoded string back to binary data tar_data = base64.b64decode(serialized_string) @@ -71,7 +71,7 @@ class CompatTensoflowTest(bctu.CompatTestBase): # Here we use tf.saved_model and provide string serialize/deserialize methods # for the whole directory. @tf.function(autograph=False, jit_compile=True) - def tf_func(the_input): # Use recognizeable names for input and result + def tf_func(the_input): # Use recognizable names for input and result res = jax2tf.convert(func, native_serialization=True)(the_input) return tf.identity(res, name="the_result") diff --git a/jax/experimental/jax2tf/tests/converters.py b/jax/experimental/jax2tf/tests/converters.py index 4aae5ce42..f0a293ca5 100644 --- a/jax/experimental/jax2tf/tests/converters.py +++ b/jax/experimental/jax2tf/tests/converters.py @@ -96,7 +96,7 @@ ALL_CONVERTERS = [ Converter(name='jax2tfjs', convert_fn=jax2tfjs, compare_numerics=False), # Convert JAX to TFLIte. Converter(name='jax2tflite', convert_fn=jax2tflite), - # Convert JAX to TFLIte with suppor for Flex ops. + # Convert JAX to TFLIte with support for Flex ops. Converter( name='jax2tflite+flex', convert_fn=functools.partial(jax2tflite, use_flex_ops=True)) diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 38aaf9446..5880b3175 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -148,7 +148,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): # TODO(b/264596006): custom calls are not registered properly with TF in OSS if (config.jax2tf_default_native_serialization and "does not work with custom calls" in str(e)): - logging.warning("Supressing error %s", e) + logging.warning("Suppressing error %s", e) raise unittest.SkipTest("b/264596006: custom calls in native serialization fail in TF") else: raise e diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index ea9b0130d..abe9539d0 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -512,7 +512,7 @@ class PolyHarness(Harness): """Args: group_name, name: The name for the harness. See `Harness.__init__`. - fun: the function to be converted, possbily after partial application to + fun: the function to be converted, possibly after partial application to static arguments from `arg_descriptors`. See `Harness.__init__`. arg_descriptors: The argument descriptors. See `Harness.__init__`. May be missing, in which case `skip_jax_run` should be `True` and diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index a95a31c7e..a818659f5 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -17,7 +17,7 @@ This module provides experimental support to CUDNN-backed LSTM. -Currrently, the only supported RNN flavor is LSTM with double-bias. We use +Currently, the only supported RNN flavor is LSTM with double-bias. We use notations and variable names similar to https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 10bc58f72..ed42031e8 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -787,7 +787,7 @@ class BCSR(JAXSparse): return repr_ def transpose(self, *args, **kwargs): - raise NotImplementedError("Tranpose is not implemented.") + raise NotImplementedError("Transpose is not implemented.") def tree_flatten(self): return (self.data, self.indices, self.indptr), self._info._asdict() diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index 171e25df8..08aa88b30 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -199,7 +199,7 @@ def _lobpcg_standard_callable( # I tried many variants of hard and soft locking [3]. All of them seemed # to worsen performance relative to no locking. # - # Further, I found a more expermental convergence formula compared to what + # Further, I found a more experimental convergence formula compared to what # is suggested in the literature, loosely based on floating-point # expectations. # @@ -451,7 +451,7 @@ def _rayleigh_ritz_orth(A, S): SAS = _mm(S.T, A(S)) - # Solve the projected subsytem. + # Solve the projected subsystem. # If we could tell to eigh to stop after first k, we would. return _eigh_ascending(SAS) diff --git a/jax/experimental/sparse/random.py b/jax/experimental/sparse/random.py index 54f60d9eb..b22d466dc 100644 --- a/jax/experimental/sparse/random.py +++ b/jax/experimental/sparse/random.py @@ -32,7 +32,7 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=None, key : random.PRNGKey to be passed to ``generator`` function. shape : tuple specifying the shape of the array to be generated. dtype : dtype of the array to be generated. - indices_dtype: dtype of the BCOO indicies. + indices_dtype: dtype of the BCOO indices. nse : number of specified elements in the matrix, or if 0 < nse < 1, a fraction of sparse dimensions to be specified (default: 0.2). n_batch : number of batch dimensions. must satisfy ``n_batch >= 0`` and diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 2da0bf3bc..5655b20aa 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -49,7 +49,7 @@ class SparseInfo(NamedTuple): #-------------------------------------------------------------------- # utilities -# TODO: possibly make these primitives, targeting cusparse rountines +# TODO: possibly make these primitives, targeting cusparse routines # csr2coo/coo2csr/SPDDMM def nfold_vmap(fun, N, *, broadcasted=True, in_axes=0): diff --git a/jax/version.py b/jax/version.py index 450adb64a..50a59a428 100644 --- a/jax/version.py +++ b/jax/version.py @@ -90,7 +90,7 @@ def _write_version(fname: str) -> None: new_version_string = f"_release_version: str = {release_version!r}" fhandle = pathlib.Path(fname) contents = fhandle.read_text() - # Expect two occurrances: one above, and one here. + # Expect two occurrences: one above, and one here. if contents.count(old_version_string) != 2: raise RuntimeError(f"Build: could not find {old_version_string!r} in {fname}") contents = contents.replace(old_version_string, new_version_string) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 3512362a0..9ffe48a15 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -648,7 +648,6 @@ class TestPromotionTables(jtu.JaxTestCase): numpy_dtype_promotion=['strict', 'standard'] ) def testSafeToCast(self, input_dtype, output_dtype, numpy_dtype_promotion): - print(input_dtype, output_dtype) with jax.numpy_dtype_promotion(numpy_dtype_promotion): # First the special cases which are always safe: always_safe = ( diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 723ec6671..1d8bd71c2 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -322,7 +322,7 @@ class ShardMapTest(jtu.JaxTestCase): def test_outer_jit_detects_shard_map_mesh(self): mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x')) - _ = jax.jit(f)(jnp.array(2.0)) # doesnt crash + _ = jax.jit(f)(jnp.array(2.0)) # doesn't crash def test_vmap_basic(self): mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))