Merge pull request #23465 from frederikwilde:typos

PiperOrigin-RevId: 671545152
This commit is contained in:
jax authors 2024-09-05 16:07:00 -07:00
commit ae400f8d2a
2 changed files with 2 additions and 2 deletions

View File

@ -306,7 +306,7 @@ from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
```
You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_xla_translation`.
You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_lowering`.
```{code-cell}
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.

View File

@ -218,7 +218,7 @@ Strict dtype promotion
----------------------
In some contexts it can be useful to disable implicit type promotion behavior, and
instead require all promotions to be explicit. This can be done in JAX by setting the
``jax_numpy_dtype_promtion`` flag to ``'strict'``. Locally, it can be done with a\
``jax_numpy_dtype_promotion`` flag to ``'strict'``. Locally, it can be done with a\
context manager:
.. code-block:: python