diff --git a/docs/key-concepts.md b/docs/key-concepts.md
index b87808d14..daab2c9fd 100644
--- a/docs/key-concepts.md
+++ b/docs/key-concepts.md
@@ -23,13 +23,13 @@ This section briefly introduces some key concepts of the JAX package.
 ## JAX arrays ({class}`jax.Array`)
 
 The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to
-the {class}`numpy.ndarray` type that you may be familar with from the NumPy package, but it
+the {class}`numpy.ndarray` type that you may be familiar with from the NumPy package, but it
 has some important differences.
 
 ### Array creation
 
 We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions.
-For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality
+For example, {mod}`jax.numpy` provides familiar NumPy-style array construction functionality
 such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc.
 
 ```{code-cell}
@@ -147,10 +147,10 @@ jaxprs later in {ref}`jax-internals-jaxpr`.
 ## Pytrees
 
 JAX functions and transformations fundamentally operate on arrays, but in practice it is
-convenient to write code that work with collections of arrays: for example, a neural
+convenient to write code that works with collections of arrays: for example, a neural
 network might organize its parameters in a dictionary of arrays with meaningful keys.
 Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree`
-abstraction to treat such collections in a uniform matter.
+abstraction to treat such collections in a uniform manner.
 
 Here are some examples of objects that can be treated as pytrees:
 
diff --git a/docs/quickstart.md b/docs/quickstart.md
index e19cb33ea..77cbb9d46 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -16,7 +16,7 @@ kernelspec:
 
 
 
-**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
+**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
 
 This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:
 
@@ -88,8 +88,8 @@ _ = selu_jit(x) # compiles on first call
 %timeit selu_jit(x).block_until_ready()
 ```
 
-The above timing represent execution on CPU, but the same code can be run on GPU or TPU,
-typically for an even greater speedup.
+The above timing represents execution on CPU, but the same code can be run on GPU or
+TPU, typically for an even greater speedup.
 
 For more on JIT compilation in JAX, check out {ref}`jit-compilation`.
 
@@ -183,7 +183,7 @@ print('Naively batched')
 %timeit naively_batched_apply_matrix(batched_x).block_until_ready()
 ```
 
-A programmer familiar with the the `jnp.dot` function might recognize that `apply_matrix` can
+A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can
 be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`:
 
 ```{code-cell}
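
The pytree paragraph patched in `docs/key-concepts.md` describes treating nested collections of arrays uniformly. A minimal sketch of what that looks like in practice, using `jax.tree_util.tree_map`; the parameter names and shapes here are illustrative, not taken from the docs:

```python
import jax
import jax.numpy as jnp

# A pytree of parameters: a nested dict whose leaves are arrays.
# (Hypothetical layer names, for illustration only.)
params = {
    "layer1": {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)},
    "layer2": {"w": jnp.ones((2, 1)), "b": jnp.zeros(1)},
}

# tree_map applies a function to every array leaf while
# preserving the container structure, so the whole parameter
# collection can be handled as one object.
scaled = jax.tree_util.tree_map(lambda x: 0.1 * x, params)

print(jax.tree_util.tree_structure(scaled))  # same structure as params
```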
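
For the `selu_jit` timing discussion patched in `docs/quickstart.md`, a self-contained sketch of the pattern being timed. The `selu` definition follows the standard SELU formula used in the quickstart; the input `x` is an illustrative stand-in for the docs' random array. `block_until_ready()` matters because JAX dispatches work asynchronously, so without it a timing loop measures dispatch rather than compute:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Scaled exponential linear unit activation.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

x = jnp.linspace(-3.0, 3.0, 1_000_000)  # stand-in input; the docs use random data
_ = selu_jit(x)                         # first call traces and compiles
selu_jit(x).block_until_ready()         # wait for the async result before timing
```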
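
The last hunk's sentence about rewriting `apply_matrix` via the built-in batching semantics of `jnp.dot` can be made concrete. The arrays below are illustrative stand-ins (the quickstart builds them with `jax.random`); the equivalence checked at the end is the point: a hand-batched `jnp.dot` and `jax.vmap` applied to the single-example function agree:

```python
import jax
import jax.numpy as jnp

mat = jnp.ones((150, 100))       # stand-in weight matrix
batched_x = jnp.ones((10, 100))  # batch of 10 input vectors

def apply_matrix(x):
    # Acts on a single vector of shape (100,).
    return jnp.dot(mat, x)

def batched_apply_matrix(batch):
    # Hand-batched: (10, 100) @ (100, 150) -> (10, 150).
    return jnp.dot(batch, mat.T)

# jax.vmap derives the batched version from the single-example one.
vmapped = jax.vmap(apply_matrix)

assert jnp.allclose(batched_apply_matrix(batched_x), vmapped(batched_x))
```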