mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix typos across the package
This commit is contained in:
parent
6dcda12140
commit
4a5bd9e046
@ -304,7 +304,7 @@ treated as dynamic. Here's how it might look::
|
||||
... CustomClass._tree_unflatten)
|
||||
|
||||
This is certainly more involved, but it solves all the issues associated with the simpler
|
||||
apporaches used above::
|
||||
approaches used above::
|
||||
|
||||
>>> c = CustomClass(2, True)
|
||||
>>> print(c.calc(3))
|
||||
@ -442,7 +442,7 @@ When run with a GPU in Colab_, we see:
|
||||
- JAX takes 193 ms to compile the function
|
||||
- JAX takes 485 µs per evaluation on the GPU
|
||||
|
||||
In this case, we see that once the data is transfered and the function is
|
||||
In this case, we see that once the data is transferred and the function is
|
||||
compiled, JAX on the GPU is about 30x faster for repeated evaluations.
|
||||
|
||||
Is this a fair comparison? Maybe. The performance that ultimately matters is for
|
||||
|
@ -620,7 +620,7 @@
|
||||
"\n",
|
||||
"Of course, it's possible to mix side-effectful Python code and functionally pure JAX code, and we will touch on this more later. As you get more familiar with JAX, you will learn how and when this can work. As a rule of thumb, however, any functions intended to be transformed by JAX should avoid side-effects, and the JAX primitives themselves will try to help you do that.\n",
|
||||
"\n",
|
||||
"We will explain other places where the JAX idiosyncracies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs."
|
||||
"We will explain other places where the JAX idiosyncrasies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -292,7 +292,7 @@ Isn't the pure version less efficient? Strictly, yes; we are creating a new arra
|
||||
|
||||
Of course, it's possible to mix side-effectful Python code and functionally pure JAX code, and we will touch on this more later. As you get more familiar with JAX, you will learn how and when this can work. As a rule of thumb, however, any functions intended to be transformed by JAX should avoid side-effects, and the JAX primitives themselves will try to help you do that.
|
||||
|
||||
We will explain other places where the JAX idiosyncracies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs.
|
||||
We will explain other places where the JAX idiosyncrasies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs.
|
||||
|
||||
+++ {"id": "dFn_VBFFlGCz"}
|
||||
|
||||
|
@ -392,7 +392,7 @@
|
||||
"\n",
|
||||
"* `SequenceKey(idx: int)`: for lists and tuples.\n",
|
||||
"* `DictKey(key: Hashable)`: for dictionaries.\n",
|
||||
"* `GetAttrKey(name: str)`: for `namedtuple`s and preferrably custom pytree nodes (more in the next section)\n",
|
||||
"* `GetAttrKey(name: str)`: for `namedtuple`s and preferably custom pytree nodes (more in the next section)\n",
|
||||
"\n",
|
||||
"You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression."
|
||||
]
|
||||
|
@ -233,7 +233,7 @@ To express key paths, JAX provides a few default key types for the built-in pytr
|
||||
|
||||
* `SequenceKey(idx: int)`: for lists and tuples.
|
||||
* `DictKey(key: Hashable)`: for dictionaries.
|
||||
* `GetAttrKey(name: str)`: for `namedtuple`s and preferrably custom pytree nodes (more in the next section)
|
||||
* `GetAttrKey(name: str)`: for `namedtuple`s and preferably custom pytree nodes (more in the next section)
|
||||
|
||||
You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression.
|
||||
|
||||
|
@ -166,7 +166,7 @@ Inputs to JAX functions and methods should be typed as permissively as is reason
|
||||
|
||||
- `ArrayLike` would be a union of anything that can be implicitly converted into an array: for example, jax arrays, numpy arrays, JAX tracers, and python or numpy scalars
|
||||
- `DTypeLike` would be a union of anything that can be implicitly converted into a dtype: for example, numpy dtypes, numpy dtype objects, jax dtype objects, strings, and built-in types.
|
||||
- `ShapeLike` would be a union of anything that could be converted into a shape: for example, sequences of integer or integer-like objecs.
|
||||
- `ShapeLike` would be a union of anything that could be converted into a shape: for example, sequences of integer or integer-like objects.
|
||||
- etc.
|
||||
|
||||
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DTypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in place of arrays, so the type definition will be simpler than the NumPy analog.
|
||||
@ -177,7 +177,7 @@ Conversely, outputs of functions and methods should be typed as strictly as poss
|
||||
|
||||
- `Array` or `NDArray` (see below) for type annotation purposes is effectively equivalent to `Union[Tracer, jnp.ndarray]` and should be used to annotate array outputs.
|
||||
- `DType` is an alias of `np.dtype`, perhaps with the ability to also represent key types and other generalizations used within JAX.
|
||||
- `Shape` is essentially `Tuple[int, ...]`, perhaps with some additional flexibilty to account for dynamic shapes.
|
||||
- `Shape` is essentially `Tuple[int, ...]`, perhaps with some additional flexibility to account for dynamic shapes.
|
||||
- `NamedShape` is an extension of `Shape` that allows for named shapes as used internall in JAX.
|
||||
- etc.
|
||||
|
||||
@ -283,7 +283,7 @@ Finally, we could opt for full unification via restructuring of the class hierar
|
||||
Here `jnp.ndarray` could be an alias for `jax.Array`.
|
||||
This final approach is in some senses the most pure, but it is somewhat forced from an OOP design standpoint (`Tracer` *is an* `Array`?).
|
||||
|
||||
##### Option 4: Parial unification via class hierarchy
|
||||
##### Option 4: Partial unification via class hierarchy
|
||||
We could make the class hierarchy more sensible by making `Tracer` and the class for
|
||||
on-device arrays inherit from a common base class. So, for example:
|
||||
|
||||
|
@ -505,7 +505,7 @@ than an `int`).
|
||||
Here are reasons we like unmapped inputs and outputs for `shmap`:
|
||||
* **Same expressiveness as `pjit`.** Anything `pjit` can do, the `shmap` escape
|
||||
hatch should be able to do too. Or else we'd have a lacking escape hatch! If
|
||||
we didnt have unmapped outputs in `shmap` then we couldn't express the same
|
||||
we didn't have unmapped outputs in `shmap` then we couldn't express the same
|
||||
batch-parallel loss function computations as `pjit`.
|
||||
* **Closed-over inputs.** Closed-over inputs essentially correspond to unmapped
|
||||
inputs, and...
|
||||
|
@ -946,7 +946,7 @@
|
||||
"source": [
|
||||
"### How to handle `uint64`?\n",
|
||||
"\n",
|
||||
"The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer opertion involving `uint64` should result in `int128`, but this is not a standard available dtype.\n",
|
||||
"The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer operation involving `uint64` should result in `int128`, but this is not a standard available dtype.\n",
|
||||
"\n",
|
||||
"Numpy's choice here is to promote to `float64`:"
|
||||
]
|
||||
@ -1919,7 +1919,7 @@
|
||||
"source": [
|
||||
"### Tensorflow Type Promotion\n",
|
||||
"\n",
|
||||
"Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercable to the type of `x`."
|
||||
"Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercible to the type of `x`."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -457,7 +457,7 @@ Again, the connections added here are precisely the promotion semantics implemen
|
||||
|
||||
### How to handle `uint64`?
|
||||
|
||||
The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer opertion involving `uint64` should result in `int128`, but this is not a standard available dtype.
|
||||
The approached to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, the output of a mixed-integer operation involving `uint64` should result in `int128`, but this is not a standard available dtype.
|
||||
|
||||
Numpy's choice here is to promote to `float64`:
|
||||
|
||||
@ -796,7 +796,7 @@ display.HTML(table.to_html())
|
||||
|
||||
### Tensorflow Type Promotion
|
||||
|
||||
Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercable to the type of `x`.
|
||||
Tensorflow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercible to the type of `x`.
|
||||
|
||||
```{code-cell}
|
||||
:cellView: form
|
||||
|
@ -203,7 +203,7 @@
|
||||
"id": "40oy-FbmVkDc"
|
||||
},
|
||||
"source": [
|
||||
"When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility definied in this notebook:"
|
||||
"When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility defined in this notebook:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -359,7 +359,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"In both `jax.linearize` and `jax.vjp` there is flexibilty in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.\n",
|
||||
"In both `jax.linearize` and `jax.vjp` there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.\n",
|
||||
"\n",
|
||||
"One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of `sin_vjp`:"
|
||||
]
|
||||
|
@ -99,7 +99,7 @@ jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
|
||||
|
||||
+++ {"id": "40oy-FbmVkDc"}
|
||||
|
||||
When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility definied in this notebook:
|
||||
When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility defined in this notebook:
|
||||
|
||||
```{code-cell}
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
@ -162,7 +162,7 @@ You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.read
|
||||
|
||||
|
||||
|
||||
In both `jax.linearize` and `jax.vjp` there is flexibilty in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.
|
||||
In both `jax.linearize` and `jax.vjp` there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.
|
||||
|
||||
One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of `sin_vjp`:
|
||||
|
||||
|
@ -129,7 +129,7 @@
|
||||
"ax[0].imshow(image, cmap='binary_r')\n",
|
||||
"ax[0].set_title('original')\n",
|
||||
"\n",
|
||||
"# Create a noisy version by adding random Gausian noise\n",
|
||||
"# Create a noisy version by adding random Gaussian noise\n",
|
||||
"key = random.PRNGKey(1701)\n",
|
||||
"noisy_image = image + 50 * random.normal(key, image.shape)\n",
|
||||
"ax[1].imshow(noisy_image, cmap='binary_r')\n",
|
||||
@ -722,7 +722,7 @@
|
||||
"# N,H,W,C = img.shape\n",
|
||||
"# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))\n",
|
||||
"\n",
|
||||
"# transposed conv = 180deg kernel roation plus LHS dilation\n",
|
||||
"# transposed conv = 180deg kernel rotation plus LHS dilation\n",
|
||||
"# rotate kernel 180deg:\n",
|
||||
"kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))\n",
|
||||
"# need a custom output padding:\n",
|
||||
|
@ -83,7 +83,7 @@ image = jnp.array(misc.face().mean(-1))
|
||||
ax[0].imshow(image, cmap='binary_r')
|
||||
ax[0].set_title('original')
|
||||
|
||||
# Create a noisy version by adding random Gausian noise
|
||||
# Create a noisy version by adding random Gaussian noise
|
||||
key = random.PRNGKey(1701)
|
||||
noisy_image = image + 50 * random.normal(key, image.shape)
|
||||
ax[1].imshow(noisy_image, cmap='binary_r')
|
||||
@ -330,7 +330,7 @@ We can use the last to, for instance, implement _transposed convolutions_:
|
||||
# N,H,W,C = img.shape
|
||||
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))
|
||||
|
||||
# transposed conv = 180deg kernel roation plus LHS dilation
|
||||
# transposed conv = 180deg kernel rotation plus LHS dilation
|
||||
# rotate kernel 180deg:
|
||||
kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))
|
||||
# need a custom output padding:
|
||||
|
@ -333,7 +333,7 @@
|
||||
"id": "LrvdAloMZbIe"
|
||||
},
|
||||
"source": [
|
||||
"By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may elliminate the callback entirely:"
|
||||
"By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -398,7 +398,7 @@
|
||||
"\n",
|
||||
"In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.\n",
|
||||
"\n",
|
||||
"As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generaing a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)."
|
||||
"As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -171,7 +171,7 @@ For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pu
|
||||
|
||||
+++ {"id": "LrvdAloMZbIe"}
|
||||
|
||||
By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may elliminate the callback entirely:
|
||||
By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:
|
||||
|
||||
```{code-cell}
|
||||
:id: mmFc_zawZrBq
|
||||
@ -208,7 +208,7 @@ In `f2` on the other hand, the output of the callback is unused, and so the comp
|
||||
|
||||
In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.
|
||||
|
||||
As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generaing a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!).
|
||||
As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!).
|
||||
|
||||
```{code-cell}
|
||||
:id: eAg5xIhrOiWV
|
||||
|
@ -173,7 +173,7 @@
|
||||
"A 2D grid\n",
|
||||
"</center>\n",
|
||||
"\n",
|
||||
"When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invokations is referred to as a \"program\", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invokation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.\n",
|
||||
"When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invocations is referred to as a \"program\", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.\n",
|
||||
"\n",
|
||||
"Here's an example kernel that uses a `grid` and `program_id`."
|
||||
]
|
||||
|
@ -112,7 +112,7 @@ We run the kernel function once for each element, a style of single-program mult
|
||||
A 2D grid
|
||||
</center>
|
||||
|
||||
When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invokations is referred to as a "program", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invokation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.
|
||||
When we provide a `grid` to `pallas_call`, the kernel is executed as many times as `prod(grid)`. Each of these invocations is referred to as a "program", To access which program (i.e. which element of the grid) the kernel is currently executing, we use `program_id(axis=...)`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.
|
||||
|
||||
Here's an example kernel that uses a `grid` and `program_id`.
|
||||
|
||||
|
@ -86,7 +86,7 @@ with compute.
|
||||
What's more, compared to GPUs, TPUs are actually highly sequential machines.
|
||||
That's why, the grid is generally not processed in parallel, but sequentially,
|
||||
in lexicographic order (though see the `Multicore TPU configurations`_ section
|
||||
for exceptions). This unlocks some interesting capabilites:
|
||||
for exceptions). This unlocks some interesting capabilities:
|
||||
|
||||
* When two (lexicographically) consecutive grid indices use the same slice of
|
||||
an input, the HBM transfer in the second iteration is skipped, as the data is
|
||||
|
@ -197,7 +197,7 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj):
|
||||
# be part of the cache key as their inclusion will result in unnecessary cache
|
||||
# misses. Clear them here by setting bool values to False, ints to 0, and
|
||||
# strings to empty. The exact values used to clear are not relevant as long
|
||||
# as the same values are used everytime for each field.
|
||||
# as the same values are used every time for each field.
|
||||
debug_options = compile_options_copy.executable_build_options.debug_options
|
||||
# LINT.IfChange(debug_options)
|
||||
debug_options.xla_force_host_platform_device_count = 0
|
||||
|
@ -82,7 +82,7 @@ class ClusterEnv:
|
||||
"""Returns address and port used by JAX to bootstrap.
|
||||
|
||||
Process id 0 will open a tcp socket at "hostname:port" where
|
||||
all the proccesses will connect to initialize the distributed JAX service.
|
||||
all the processes will connect to initialize the distributed JAX service.
|
||||
The selected port needs to be free.
|
||||
:func:`get_coordinator_address` needs to return the same hostname and port on all the processes.
|
||||
|
||||
|
@ -801,7 +801,7 @@ spmd_mode = config.define_enum_state(
|
||||
" to execute on non-fully addressable `jax.Array`s\n"
|
||||
"* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
|
||||
" `jax.jit` and all other operations are allowed to "
|
||||
" execute on non-fully addresable `jax.Array`s."))
|
||||
" execute on non-fully addressable `jax.Array`s."))
|
||||
|
||||
|
||||
distributed_debug = config.define_bool_state(
|
||||
|
@ -1198,7 +1198,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
|
||||
Args:
|
||||
fun: a Python callable specifying a linear function. It should
|
||||
take two arguments: one of "residual" inputs (type ``r``),
|
||||
i.e. inputs in which the function is not necessarly linear, and
|
||||
i.e. inputs in which the function is not necessarily linear, and
|
||||
one of "linear" inputs (type ``a``). It should return output
|
||||
whose components are linear in the linear input (type ``b``).
|
||||
fun_transpose: a Python callable specifying a structurally linear
|
||||
|
@ -165,7 +165,7 @@ def initialize(coordinator_address: Optional[str] = None,
|
||||
|
||||
Example:
|
||||
|
||||
Suppose there are two GPU processs, and process 0 is the designated coordinator
|
||||
Suppose there are two GPU processes, and process 0 is the designated coordinator
|
||||
with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the
|
||||
following commands before anything else.
|
||||
|
||||
|
@ -535,7 +535,7 @@ class UnexpectedTracerError(JAXTypeError):
|
||||
JAX detects leaks when you then use the leaked value in another
|
||||
operation later on, at which point it raises an ``UnexpectedTracerError``.
|
||||
To fix this, avoid side effects: if a function computes a value needed
|
||||
in an outer scope, return that value from the transformed function explictly.
|
||||
in an outer scope, return that value from the transformed function explicitly.
|
||||
|
||||
Specifically, a ``Tracer`` is JAX's internal representation of a function's
|
||||
intermediate values during transformations, e.g. within :func:`~jax.jit`,
|
||||
|
@ -406,7 +406,7 @@ class BatchTrace(Trace):
|
||||
else:
|
||||
axis_size = None # can't be inferred from data
|
||||
if self.axis_name is core.no_axis_name:
|
||||
assert axis_size is not None # must be inferrable from data
|
||||
assert axis_size is not None # must be inferable from data
|
||||
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
|
||||
frame = core.axis_frame(self.axis_name, self.main)
|
||||
assert axis_size is None or axis_size == frame.size, (axis_size, frame.size)
|
||||
@ -1063,7 +1063,7 @@ def _mask_one_ragged_axis(
|
||||
ragged_axis, segment_lengths = axis_spec.ragged_axes[0]
|
||||
value = ident(operand.dtype)
|
||||
positions = jax.lax.broadcasted_iota('int32', operand.shape, ragged_axis)
|
||||
# TODO(mattjj, axch) cant get ._data, need to convert it
|
||||
# TODO(mattjj, axch) can't get ._data, need to convert it
|
||||
# lengths = jax.lax.convert_element_type(segment_lengths._data, 'int32')
|
||||
lengths = jax.lax.convert_element_type(segment_lengths, 'int32')
|
||||
limits = jax.lax.broadcast_in_dim(
|
||||
|
@ -1058,7 +1058,7 @@ def partial_eval_jaxpr_nounits(
|
||||
In the above example, the output of jaxpr_known named `d` is a _residual_
|
||||
output, and corresponds to the input named `a` in jaxpr_unknown. In general,
|
||||
jaxpr_known will produce extra outputs (at the end of its output list)
|
||||
corresponding to intermeidate values of the original jaxpr which must be
|
||||
corresponding to intermediate values of the original jaxpr which must be
|
||||
passed to jaxpr_unknown (as leading inputs).
|
||||
"""
|
||||
instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
|
||||
@ -2411,7 +2411,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]:
|
||||
invars = [*jaxpr.constvars, *jaxpr.invars]
|
||||
expl_outvars = jaxpr.outvars
|
||||
|
||||
# First do a pass to collect implicit outputs, meaning variables which occurr
|
||||
# First do a pass to collect implicit outputs, meaning variables which occur
|
||||
# in explicit_outvars types but not in invars or to the left in outvars.
|
||||
seen: set[Var] = set(invars)
|
||||
impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore
|
||||
|
@ -2856,7 +2856,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
|
||||
in_shardings = [gs] * num_in_shardings
|
||||
out_shardings = [gs] * num_out_shardings
|
||||
# jit(pmap) will generate Arrays with multi-device sharding.
|
||||
# It is unsupported for these shardings to be uncommited, so force
|
||||
# It is unsupported for these shardings to be uncommitted, so force
|
||||
# the outputs to be committed.
|
||||
committed = True
|
||||
return in_shardings, out_shardings, committed, tuple(local_devices)
|
||||
|
@ -203,7 +203,7 @@ def approx_min_k(operand: Array,
|
||||
|
||||
In the example above, we compute ``db^2/2 - dot(qy, db^T)`` instead of
|
||||
``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
|
||||
arithmetics and produces the same set of neighbors.
|
||||
arithmetic and produces the same set of neighbors.
|
||||
"""
|
||||
return approx_top_k_p.bind(
|
||||
operand,
|
||||
|
@ -719,7 +719,7 @@ def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
|
||||
|
||||
def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear, unroll):
|
||||
# if any in_ct is nonzero, we definitely want it in args_ (and the
|
||||
# corresponding x in args could be an undefined primal, but doesnt have to be)
|
||||
# corresponding x in args could be an undefined primal, but doesn't have to be)
|
||||
# for non-res stuff:
|
||||
# getting and setting => (nonzero ct, UndefinedPrimal arg)
|
||||
# just setting => (nonzero ct, not UndefinedPrimal, dummy value)
|
||||
|
@ -3286,7 +3286,7 @@ mlir.register_lowering(pad_p, _pad_lower)
|
||||
# For N > 1, we can match up the output array axis with the second axis of the
|
||||
# input. But for N = 1, it is not clear how axes match up: all we know from the
|
||||
# JAXpr is that we are reshaping from (1, 1) to (1,).
|
||||
# In constrast, squeeze[ dimensions=(0,) ] is unambiguous.
|
||||
# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.
|
||||
|
||||
|
||||
def _squeeze_dtype_rule(operand, *, dimensions):
|
||||
|
@ -2038,7 +2038,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array:
|
||||
The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
|
||||
Note that ``dl[0] = 0``.
|
||||
d: A batch of vectors with shape ``[..., m]``.
|
||||
The middle diagnoal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
|
||||
The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
|
||||
du: A batch of vectors with shape ``[..., m]``.
|
||||
The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
|
||||
Note that ``dl[m - 1] = 0``.
|
||||
|
@ -605,7 +605,7 @@ def bessel_i0e_impl(x):
|
||||
elif x.dtype == np.float32:
|
||||
return _i0e_impl32(x)
|
||||
else:
|
||||
# Have to upcast f16 because the magic Cephes coefficents don't have enough
|
||||
# Have to upcast f16 because the magic Cephes coefficients don't have enough
|
||||
# precision for it.
|
||||
x_dtype = x.dtype
|
||||
x = x.astype(np.float32)
|
||||
|
@ -55,7 +55,7 @@ zip, unsafe_zip = safe_zip, zip
|
||||
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
|
||||
"""Copy the array and cast to a specified dtype.
|
||||
|
||||
This is implemeted via :func:`jax.lax.convert_element_type`, which may
|
||||
This is implemented via :func:`jax.lax.convert_element_type`, which may
|
||||
have slightly different behavior than :meth:`numpy.ndarray.astype` in
|
||||
some cases. In particular, the details of float-to-int and int-to-float
|
||||
casts are implementation dependent.
|
||||
@ -346,7 +346,7 @@ class _IndexUpdateHelper:
|
||||
"""Helper property for index update functionality.
|
||||
|
||||
The ``at`` property provides a functionally pure equivalent of in-place
|
||||
array modificatons.
|
||||
array modifications.
|
||||
|
||||
In particular:
|
||||
|
||||
|
@ -1840,7 +1840,7 @@ def with_sharding_constraint(x, shardings):
|
||||
of how to use this function, see `Distributed arrays and automatic parallelization`_.
|
||||
|
||||
Args:
|
||||
x: PyTree of jax.Arrays which will have their shardings constrainted
|
||||
x: PyTree of jax.Arrays which will have their shardings constrained
|
||||
shardings: PyTree of sharding specifications. Valid values are the same as for
|
||||
the ``in_shardings`` argument of :func:`jax.experimental.pjit`.
|
||||
Returns:
|
||||
|
@ -104,7 +104,7 @@ def start_trace(log_dir, create_perfetto_link: bool = False,
|
||||
Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
|
||||
generated if ``create_perfetto_link`` is true. This could be useful if you
|
||||
want to generate a Perfetto-compatible trace without blocking the
|
||||
processs.
|
||||
process.
|
||||
"""
|
||||
with _profile_state.lock:
|
||||
if _profile_state.profile_session is not None:
|
||||
@ -228,7 +228,7 @@ def trace(log_dir, create_perfetto_link=False, create_perfetto_trace=False):
|
||||
Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
|
||||
generated if ``create_perfetto_link`` is true. This could be useful if you
|
||||
want to generate a Perfetto-compatible trace without blocking the
|
||||
processs.
|
||||
process.
|
||||
"""
|
||||
start_trace(log_dir, create_perfetto_link, create_perfetto_trace)
|
||||
try:
|
||||
|
@ -865,7 +865,7 @@ def _gen_derivatives(p: Array,
|
||||
|
||||
Args:
|
||||
p: The 3D array containing the values of associated Legendre functions; the
|
||||
dimensions are in the sequence of order (m), degree (l), and evalution
|
||||
dimensions are in the sequence of order (m), degree (l), and evaluation
|
||||
points.
|
||||
x: A vector of type `float32` or `float64` containing the sampled points.
|
||||
is_normalized: True if the associated Legendre functions are normalized.
|
||||
@ -962,7 +962,7 @@ def _gen_associated_legendre(l_max: int,
|
||||
harmonic of degree `l` and order `m` can be written as
|
||||
`Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
|
||||
normalization factor and θ and φ are the colatitude and longitude,
|
||||
repectively. `N_l^m` is chosen in the way that the spherical harmonics form
|
||||
respectively. `N_l^m` is chosen in the way that the spherical harmonics form
|
||||
a set of orthonormal basis function of L^2(S^2). For the computational
|
||||
efficiency of spherical harmonics transform, the normalization factor is
|
||||
used in the computation of the ALFs. In addition, normalizing `P_l^m`
|
||||
@ -999,7 +999,7 @@ def _gen_associated_legendre(l_max: int,
|
||||
Returns:
|
||||
The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
|
||||
of the ALFs at `x`; the dimensions in the sequence of order, degree, and
|
||||
evalution points.
|
||||
evaluation points.
|
||||
"""
|
||||
p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]), dtype=x.dtype)
|
||||
|
||||
@ -1106,7 +1106,7 @@ def lpmn_values(m: int, n: int, z: Array, is_normalized: bool) -> Array:
|
||||
spherical harmonic of degree `l` and order `m` can be written as
|
||||
:math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`,
|
||||
where :math:`N_l^m` is the normalization factor and θ and φ are the
|
||||
colatitude and longitude, repectively. :math:`N_l^m` is chosen in the
|
||||
colatitude and longitude, respectively. :math:`N_l^m` is chosen in the
|
||||
way that the spherical harmonics form a set of orthonormal basis function
|
||||
of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids overflow/underflow
|
||||
and achieves better numerical stability.
|
||||
@ -1192,7 +1192,7 @@ def sph_harm(m: Array,
|
||||
:math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
|
||||
where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
|
||||
{4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and
|
||||
:math:`\theta` are the colatitude and longitude, repectively. :math:`N_n^m` is
|
||||
:math:`\theta` are the colatitude and longitude, respectively. :math:`N_n^m` is
|
||||
chosen in the way that the spherical harmonics form a set of orthonormal basis
|
||||
functions of :math:`L^2(S^2)`.
|
||||
|
||||
@ -1600,7 +1600,7 @@ def expn_jvp(n, primals, tangents):
|
||||
@_wraps(osp_special.exp1, module="scipy.special")
|
||||
def exp1(x: ArrayLike, module='scipy.special') -> Array:
|
||||
x, = promote_args_inexact("exp1", x)
|
||||
# Casting becuase custom_jvp generic does not work correctly with mypy.
|
||||
# Casting because custom_jvp generic does not work correctly with mypy.
|
||||
return cast(Array, expn(1, x))
|
||||
|
||||
|
||||
|
@ -477,7 +477,7 @@ class Compiled(Stage):
|
||||
# This is because `__call__` passes in `self._params` as the first argument.
|
||||
# Instead of making the call signature `call(params, *args, **kwargs)`
|
||||
# extract it from args because `params` can be passed as a kwarg by users
|
||||
# which might confict here.
|
||||
# which might conflict here.
|
||||
params = args[0]
|
||||
args = args[1:]
|
||||
if jax.config.jax_dynamic_shapes:
|
||||
|
@ -629,7 +629,7 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool]
|
||||
def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,
|
||||
which_linear: tuple[bool, ...]):
|
||||
# if any in_ct is nonzero, we definitely want it in args_ (and the
|
||||
# corresponding x in args could be an undefined primal, but doesnt have to be)
|
||||
# corresponding x in args could be an undefined primal, but doesn't have to be)
|
||||
# for non-res stuff:
|
||||
# getting and setting => (nonzero ct, UndefinedPrimal arg)
|
||||
# just setting => (nonzero ct, not UndefinedPrimal, dummy value)
|
||||
|
@ -216,7 +216,7 @@ def _get_pjrt_plugin_names_and_library_paths(
|
||||
"""Gets the names and library paths of PJRT plugins to load from env var.
|
||||
|
||||
Args:
|
||||
plugins_from_env: plugin name and pathes from env var. It is in the format
|
||||
plugins_from_env: plugin name and paths from env var. It is in the format
|
||||
of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for windows).
|
||||
|
||||
Returns:
|
||||
|
@ -287,7 +287,7 @@ async def async_deserialize(
|
||||
# transfer instead of loading data. In the future, if memory pressure
|
||||
# becomes a problem, we can instead instrument bytelimiter to
|
||||
# keep track of all in-flight tensors and only block_until_ready, if byte
|
||||
# limiter hits the limit to get reduced memory usage, without loosing
|
||||
# limiter hits the limit to get reduced memory usage, without losing
|
||||
# performance in common use cases.
|
||||
await byte_limiter.release_bytes(requested_bytes)
|
||||
return result
|
||||
|
@ -991,7 +991,7 @@ def _call_exported_abstract_eval(*in_avals: core.AbstractValue,
|
||||
# We discharge all the constraints statically. This results in much simpler
|
||||
# composability (because we do not have to worry about the constraints of the
|
||||
# Exported called recursively; we only need to worry about entry-point
|
||||
# constraints). This also makes sense from a composibility point of view,
|
||||
# constraints). This also makes sense from a composability point of view,
|
||||
# because we get the same errors if we invoke the exported module, or if we
|
||||
# trace the exported function. Consider for example, an exported module with
|
||||
# signature `f32[a, a] -> f32[a]`. If we invoke the module with an argument
|
||||
|
@ -25,7 +25,7 @@ This enables many JAX programs to be traced with symbolic dimensions
|
||||
in some dimensions. A priority has been to enable the batch
|
||||
dimension in neural network examples to be polymorphic.
|
||||
|
||||
This was built initially for jax2tf, but it is now customizeable to be
|
||||
This was built initially for jax2tf, but it is now customizable to be
|
||||
independent of TF. The best documentation at the moment is in the
|
||||
jax2tf.convert docstring, and the
|
||||
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
|
||||
@ -539,7 +539,7 @@ class _DimExpr():
|
||||
try:
|
||||
power = int(power)
|
||||
except:
|
||||
raise InconclusiveDimensionOperation(f"Symblic dimension cannot be raised to non-integer power '{self}' ^ '{power}'")
|
||||
raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'")
|
||||
return functools.reduce(op.mul, [self] * power)
|
||||
|
||||
def __floordiv__(self, divisor):
|
||||
@ -1269,7 +1269,7 @@ class ShapeConstraint:
|
||||
"""Forms the error_message and error message_inputs.
|
||||
See shape_assertion.
|
||||
"""
|
||||
# There is currenly a limitation in the shape assertion checker that
|
||||
# There is currently a limitation in the shape assertion checker that
|
||||
# it supports at most 32 error_message_inputs. We try to stay within the
|
||||
# limit, reusing a format specifier if possible.
|
||||
if jaxlib_version <= (0, 4, 14):
|
||||
|
@ -573,7 +573,7 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
|
||||
"https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html"
|
||||
" for alternatives. Please file a feature request at "
|
||||
"https://github.com/google/jax/issues if none of the alternatives are "
|
||||
"sufficent.")
|
||||
"sufficient.")
|
||||
|
||||
|
||||
xops = xla_client._xla.ops
|
||||
|
@ -151,7 +151,7 @@ and to avoid warnings and outright errors.
|
||||
|
||||
You can serialize JAX program into a TensorFlow SavedModel, for use
|
||||
with tooling that understands SavedModel. Both in native and non-native
|
||||
serialization you can count on 6 months of backwards compatiblity (you
|
||||
serialization you can count on 6 months of backwards compatibility (you
|
||||
can load a function serialized today with tooling that will be built
|
||||
up to 6 months in the future), and 3 weeks of limited forwards compatibility
|
||||
(you can load a function serialized today with tooling that was built
|
||||
@ -1165,7 +1165,7 @@ There is work underway to enable more tools to consume StableHLO.
|
||||
|
||||
Applies to native serialization only.
|
||||
|
||||
When you use native serialization, JAX will record the plaform for
|
||||
When you use native serialization, JAX will record the platform for
|
||||
which the module was serialized, and you will get an error if you
|
||||
try to execute the `XlaCallModule` TensorFlow op on another platform.
|
||||
|
||||
@ -1180,7 +1180,7 @@ The current platform CPU is not among the platforms required by the module [CUDA
|
||||
```
|
||||
|
||||
where `CPU` is the TensorFlow platform where the op is being executed
|
||||
and `CUDA` is the plaform for which the module was serialized by JAX.
|
||||
and `CUDA` is the platform for which the module was serialized by JAX.
|
||||
This probably means that JAX and TensorFlow may see different devices
|
||||
as the default device (JAX defaults to GPU and TensorFlow to CPU
|
||||
in the example error above).
|
||||
|
@ -83,7 +83,7 @@ def _transpose_with_shape(x: TfVal, x_shape: core.Shape, permutation) -> tuple[T
|
||||
|
||||
def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape,
|
||||
rhs, rhs_shape: core.Shape, dimension_numbers):
|
||||
"""Tranposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions.
|
||||
"""Transposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions.
|
||||
|
||||
The shapes passed in and returned may contain polynomials, and thus may
|
||||
be different than lhs.shape and rhs.shape.
|
||||
@ -91,7 +91,7 @@ def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape,
|
||||
# TODO(marcvanzee): Add tests for this ops for shape polymorphism.
|
||||
lhs_perm, rhs_perm, _ = dimension_numbers
|
||||
|
||||
# TODO(marcvanzee): Consider merging tranposes if we want to optimize.
|
||||
# TODO(marcvanzee): Consider merging transposes if we want to optimize.
|
||||
# For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
|
||||
lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, lhs_perm) # lhs --> "NCHW"
|
||||
if len(lhs_perm) == 3:
|
||||
@ -234,7 +234,7 @@ def _validate_conv_features(
|
||||
elif [is_depthwise, is_atrous, is_transpose].count(True) > 1:
|
||||
raise _conv_error(
|
||||
f"Can only do one of depthwise ({is_depthwise}), atrous ({is_atrous}) "
|
||||
f"and tranposed convolutions ({is_transpose})")
|
||||
f"and transposed convolutions ({is_transpose})")
|
||||
|
||||
# We can implement batch grouping when there is a need for it.
|
||||
if batch_group_count != 1:
|
||||
@ -290,7 +290,7 @@ def _conv_general_dilated(
|
||||
else:
|
||||
padding_type = pads_to_padtype(
|
||||
lhs_shape[1:3], rhs_dilated_shape, window_strides, padding)
|
||||
# We only manually pad if we aren't using a tranposed convolutions.
|
||||
# We only manually pad if we aren't using a transposed convolutions.
|
||||
if padding_type == "EXPLICIT":
|
||||
lhs, lhs_shape, padding = _check_pad_spatial_dims(lhs, lhs_shape, padding)
|
||||
padding_type = padding
|
||||
@ -716,7 +716,7 @@ def _reduce_window(*args, jaxpr, consts, window_dimensions,
|
||||
}[computation_name]
|
||||
result = reduce_fn(result, init_value)
|
||||
|
||||
# The outut is expected to be wrapped in a tuple, and since we don't use
|
||||
# The output is expected to be wrapped in a tuple, and since we don't use
|
||||
# variadic reductions, this tuple always contains a single element.
|
||||
return (result,)
|
||||
|
||||
@ -933,7 +933,7 @@ def _pre_gather_with_batch_dims(args: GatherArgs):
|
||||
# also need to re-work the output reshaping
|
||||
raise ValueError("only len(collapsed_slice_dims) == 0 is supported")
|
||||
|
||||
# NOTE: This supports higher dimensions than listed (the highest dimenison
|
||||
# NOTE: This supports higher dimensions than listed (the highest dimension
|
||||
# in the tests is 3D so it is limited to that, but the implementation is
|
||||
# designed to handle higher dimensions (N-Dimensional)).
|
||||
if len(args.batch_dims) not in [1, 2, 3]:
|
||||
@ -1209,8 +1209,8 @@ def convert_scatter_jax_to_tf(update_op, unsorted_segment_op=None):
|
||||
Wrapper around the scatter function.
|
||||
The underlying tf ops `tf.tensor_scatter_nd_update` and
|
||||
`tf.math.unsorted_segment_*` index from the front dimensions.
|
||||
`tf.math.unsorted_segment_*` indexs to a depth 1 from the front.
|
||||
`tf.tensor_scatter_nd_update` indexs from the front dimensions onwards,
|
||||
`tf.math.unsorted_segment_*` indexes to a depth 1 from the front.
|
||||
`tf.tensor_scatter_nd_update` indexes from the front dimensions onwards,
|
||||
with no ability to skip a dimension. This function shifts the axes to be
|
||||
indexed to the front then calls a front-specific implementation, then
|
||||
inverse-shifts the output.
|
||||
|
@ -1723,7 +1723,7 @@ def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
|
||||
|
||||
def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal:
|
||||
# For reducers we will need min/max for scalars only. In that case we
|
||||
# can construct the AbstractValues outselves, even in the presence of
|
||||
# can construct the AbstractValues ourselves, even in the presence of
|
||||
# shape polymorphism.
|
||||
assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}"
|
||||
aval = core.ShapedArray((), _to_jax_dtype(x.dtype))
|
||||
|
@ -50,7 +50,7 @@ def serialize_directory(directory_path):
|
||||
|
||||
|
||||
def deserialize_directory(serialized_string, output_directory):
|
||||
"""Deserialize the string to the diretory."""
|
||||
"""Deserialize the string to the directory."""
|
||||
# Convert the base64-encoded string back to binary data
|
||||
tar_data = base64.b64decode(serialized_string)
|
||||
|
||||
@ -71,7 +71,7 @@ class CompatTensoflowTest(bctu.CompatTestBase):
|
||||
# Here we use tf.saved_model and provide string serialize/deserialize methods
|
||||
# for the whole directory.
|
||||
@tf.function(autograph=False, jit_compile=True)
|
||||
def tf_func(the_input): # Use recognizeable names for input and result
|
||||
def tf_func(the_input): # Use recognizable names for input and result
|
||||
res = jax2tf.convert(func, native_serialization=True)(the_input)
|
||||
return tf.identity(res, name="the_result")
|
||||
|
||||
|
@ -96,7 +96,7 @@ ALL_CONVERTERS = [
|
||||
Converter(name='jax2tfjs', convert_fn=jax2tfjs, compare_numerics=False),
|
||||
# Convert JAX to TFLIte.
|
||||
Converter(name='jax2tflite', convert_fn=jax2tflite),
|
||||
# Convert JAX to TFLIte with suppor for Flex ops.
|
||||
# Convert JAX to TFLIte with support for Flex ops.
|
||||
Converter(
|
||||
name='jax2tflite+flex',
|
||||
convert_fn=functools.partial(jax2tflite, use_flex_ops=True))
|
||||
|
@ -148,7 +148,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
# TODO(b/264596006): custom calls are not registered properly with TF in OSS
|
||||
if (config.jax2tf_default_native_serialization and
|
||||
"does not work with custom calls" in str(e)):
|
||||
logging.warning("Supressing error %s", e)
|
||||
logging.warning("Suppressing error %s", e)
|
||||
raise unittest.SkipTest("b/264596006: custom calls in native serialization fail in TF")
|
||||
else:
|
||||
raise e
|
||||
|
@ -512,7 +512,7 @@ class PolyHarness(Harness):
|
||||
"""Args:
|
||||
|
||||
group_name, name: The name for the harness. See `Harness.__init__`.
|
||||
fun: the function to be converted, possbily after partial application to
|
||||
fun: the function to be converted, possibly after partial application to
|
||||
static arguments from `arg_descriptors`. See `Harness.__init__`.
|
||||
arg_descriptors: The argument descriptors. See `Harness.__init__`. May
|
||||
be missing, in which case `skip_jax_run` should be `True` and
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
This module provides experimental support to CUDNN-backed LSTM.
|
||||
|
||||
Currrently, the only supported RNN flavor is LSTM with double-bias. We use
|
||||
Currently, the only supported RNN flavor is LSTM with double-bias. We use
|
||||
notations and variable names similar to
|
||||
https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
|
||||
|
||||
|
@ -787,7 +787,7 @@ class BCSR(JAXSparse):
|
||||
return repr_
|
||||
|
||||
def transpose(self, *args, **kwargs):
|
||||
raise NotImplementedError("Tranpose is not implemented.")
|
||||
raise NotImplementedError("Transpose is not implemented.")
|
||||
|
||||
def tree_flatten(self):
|
||||
return (self.data, self.indices, self.indptr), self._info._asdict()
|
||||
|
@ -199,7 +199,7 @@ def _lobpcg_standard_callable(
|
||||
# I tried many variants of hard and soft locking [3]. All of them seemed
|
||||
# to worsen performance relative to no locking.
|
||||
#
|
||||
# Further, I found a more expermental convergence formula compared to what
|
||||
# Further, I found a more experimental convergence formula compared to what
|
||||
# is suggested in the literature, loosely based on floating-point
|
||||
# expectations.
|
||||
#
|
||||
@ -451,7 +451,7 @@ def _rayleigh_ritz_orth(A, S):
|
||||
|
||||
SAS = _mm(S.T, A(S))
|
||||
|
||||
# Solve the projected subsytem.
|
||||
# Solve the projected subsystem.
|
||||
# If we could tell to eigh to stop after first k, we would.
|
||||
return _eigh_ascending(SAS)
|
||||
|
||||
|
@ -32,7 +32,7 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=None,
|
||||
key : random.PRNGKey to be passed to ``generator`` function.
|
||||
shape : tuple specifying the shape of the array to be generated.
|
||||
dtype : dtype of the array to be generated.
|
||||
indices_dtype: dtype of the BCOO indicies.
|
||||
indices_dtype: dtype of the BCOO indices.
|
||||
nse : number of specified elements in the matrix, or if 0 < nse < 1, a
|
||||
fraction of sparse dimensions to be specified (default: 0.2).
|
||||
n_batch : number of batch dimensions. must satisfy ``n_batch >= 0`` and
|
||||
|
@ -49,7 +49,7 @@ class SparseInfo(NamedTuple):
|
||||
|
||||
#--------------------------------------------------------------------
|
||||
# utilities
|
||||
# TODO: possibly make these primitives, targeting cusparse rountines
|
||||
# TODO: possibly make these primitives, targeting cusparse routines
|
||||
# csr2coo/coo2csr/SPDDMM
|
||||
|
||||
def nfold_vmap(fun, N, *, broadcasted=True, in_axes=0):
|
||||
|
@ -90,7 +90,7 @@ def _write_version(fname: str) -> None:
|
||||
new_version_string = f"_release_version: str = {release_version!r}"
|
||||
fhandle = pathlib.Path(fname)
|
||||
contents = fhandle.read_text()
|
||||
# Expect two occurrances: one above, and one here.
|
||||
# Expect two occurrences: one above, and one here.
|
||||
if contents.count(old_version_string) != 2:
|
||||
raise RuntimeError(f"Build: could not find {old_version_string!r} in {fname}")
|
||||
contents = contents.replace(old_version_string, new_version_string)
|
||||
|
@ -648,7 +648,6 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
numpy_dtype_promotion=['strict', 'standard']
|
||||
)
|
||||
def testSafeToCast(self, input_dtype, output_dtype, numpy_dtype_promotion):
|
||||
print(input_dtype, output_dtype)
|
||||
with jax.numpy_dtype_promotion(numpy_dtype_promotion):
|
||||
# First the special cases which are always safe:
|
||||
always_safe = (
|
||||
|
@ -322,7 +322,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def test_outer_jit_detects_shard_map_mesh(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
|
||||
_ = jax.jit(f)(jnp.array(2.0)) # doesnt crash
|
||||
_ = jax.jit(f)(jnp.array(2.0)) # doesn't crash
|
||||
|
||||
def test_vmap_basic(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user