mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #23465 from frederikwilde:typos
PiperOrigin-RevId: 671545152
This commit is contained in:
commit
ae400f8d2a
@ -306,7 +306,7 @@ from jax.interpreters import mlir
|
||||
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
|
||||
```
|
||||
|
||||
You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_xla_translation`.
|
||||
You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_lowering`.
|
||||
|
||||
```{code-cell}
|
||||
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
|
||||
|
@ -218,7 +218,7 @@ Strict dtype promotion
|
||||
----------------------
|
||||
In some contexts it can be useful to disable implicit type promotion behavior, and
|
||||
instead require all promotions to be explicit. This can be done in JAX by setting the
|
||||
``jax_numpy_dtype_promtion`` flag to ``'strict'``. Locally, it can be done with a\
|
||||
``jax_numpy_dtype_promotion`` flag to ``'strict'``. Locally, it can be done with a\
|
||||
context manager:
|
||||
|
||||
.. code-block:: python
|
||||
|
Loading…
x
Reference in New Issue
Block a user