mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Fix Typos
This commit is contained in:
parent
6d35113686
commit
13774d1382
@ -23,13 +23,13 @@ This section briefly introduces some key concepts of the JAX package.
|
||||
## JAX arrays ({class}`jax.Array`)
|
||||
|
||||
The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to
|
||||
the {class}`numpy.ndarray` type that you may be familar with from the NumPy package, but it
|
||||
the {class}`numpy.ndarray` type that you may be familiar with from the NumPy package, but it
|
||||
has some important differences.
|
||||
|
||||
### Array creation
|
||||
|
||||
We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions.
|
||||
For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality
|
||||
For example, {mod}`jax.numpy` provides familiar NumPy-style array construction functionality
|
||||
such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc.
|
||||
|
||||
```{code-cell}
|
||||
@ -147,10 +147,10 @@ jaxprs later in {ref}`jax-internals-jaxpr`.
|
||||
## Pytrees
|
||||
|
||||
JAX functions and transformations fundamentally operate on arrays, but in practice it is
|
||||
convenient to write code that work with collections of arrays: for example, a neural
|
||||
convenient to write code that works with collection of arrays: for example, a neural
|
||||
network might organize its parameters in a dictionary of arrays with meaningful keys.
|
||||
Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree`
|
||||
abstraction to treat such collections in a uniform matter.
|
||||
abstraction to treat such collections in a uniform manner.
|
||||
|
||||
Here are some examples of objects that can be treated as pytrees:
|
||||
|
||||
|
@ -16,7 +16,7 @@ kernelspec:
|
||||
|
||||
<!--* freshness: { reviewed: '2024-06-13' } *-->
|
||||
|
||||
**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
|
||||
**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
|
||||
|
||||
This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:
|
||||
|
||||
@ -88,8 +88,8 @@ _ = selu_jit(x) # compiles on first call
|
||||
%timeit selu_jit(x).block_until_ready()
|
||||
```
|
||||
|
||||
The above timing represent execution on CPU, but the same code can be run on GPU or TPU,
|
||||
typically for an even greater speedup.
|
||||
The above timing represents execution on CPU, but the same code can be run on GPU or
|
||||
TPU, typically for an even greater speedup.
|
||||
|
||||
For more on JIT compilation in JAX, check out {ref}`jit-compilation`.
|
||||
|
||||
@ -183,7 +183,7 @@ print('Naively batched')
|
||||
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
|
||||
```
|
||||
|
||||
A programmer familiar with the the `jnp.dot` function might recognize that `apply_matrix` can
|
||||
A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can
|
||||
be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`:
|
||||
|
||||
```{code-cell}
|
||||
|
Loading…
x
Reference in New Issue
Block a user