Fix Typos

This commit is contained in:
rajasekharporeddy 2024-09-25 21:26:05 +05:30
parent 6d35113686
commit 13774d1382
2 changed files with 8 additions and 8 deletions

View File

@ -23,13 +23,13 @@ This section briefly introduces some key concepts of the JAX package.
## JAX arrays ({class}`jax.Array`)
The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to
the {class}`numpy.ndarray` type that you may be familar with from the NumPy package, but it
the {class}`numpy.ndarray` type that you may be familiar with from the NumPy package, but it
has some important differences.
### Array creation
We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions.
For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality
For example, {mod}`jax.numpy` provides familiar NumPy-style array construction functionality
such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc.
```{code-cell}
@ -147,10 +147,10 @@ jaxprs later in {ref}`jax-internals-jaxpr`.
## Pytrees
JAX functions and transformations fundamentally operate on arrays, but in practice it is
convenient to write code that work with collections of arrays: for example, a neural
convenient to write code that works with collection of arrays: for example, a neural
network might organize its parameters in a dictionary of arrays with meaningful keys.
Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree`
abstraction to treat such collections in a uniform matter.
abstraction to treat such collections in a uniform manner.
Here are some examples of objects that can be treated as pytrees:

View File

@ -16,7 +16,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-06-13' } *-->
**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:
@ -88,8 +88,8 @@ _ = selu_jit(x) # compiles on first call
%timeit selu_jit(x).block_until_ready()
```
The above timing represent execution on CPU, but the same code can be run on GPU or TPU,
typically for an even greater speedup.
The above timing represents execution on CPU, but the same code can be run on GPU or
TPU, typically for an even greater speedup.
For more on JIT compilation in JAX, check out {ref}`jit-compilation`.
@ -183,7 +183,7 @@ print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
```
A programmer familiar with the the `jnp.dot` function might recognize that `apply_matrix` can
A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can
be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`:
```{code-cell}