Merge pull request #24886 from carlosgmartin:fix_typos

PiperOrigin-RevId: 696272953
This commit is contained in:
jax authors 2024-11-13 14:33:38 -08:00
commit 426e13a5aa
3 changed files with 4 additions and 4 deletions

View File

@ -679,7 +679,7 @@ class RmsNormFwdClass:
NamedSharding(mesh, PartitionSpec(None, None)))
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
output_shardings = (arg_shardings[0], invvar_sharding)
# Sharded_impl only accepts positional arugments
# Sharded_impl only accepts positional arguments
# And they should be Jax traceable variables
impl = partial(RmsNormFwdClass.impl, eps=eps)
@ -739,7 +739,7 @@ class RmsNormBwdClass:
output_shardings = (output_sharding, invvar_sharding, invvar_sharding)
# Sharded_impl only accepts positional arugments
# Sharded_impl only accepts positional arguments
# And they should be Jax traceable variables
def impl(g, invvar, x, weight):
grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(

View File

@ -353,7 +353,7 @@ class RmsNormFwdClass:
NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything.
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
output_shardings = (arg_shardings[0], invvar_sharding)
# Sharded_impl only accepts positional arugments
# Sharded_impl only accepts positional arguments
# And they should be Jax traceable variables
impl = partial(RmsNormFwdClass.impl, eps=eps)

View File

@ -343,7 +343,7 @@ def pure_callback(
* Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method``
is deprecated and it will eventually raise ``NotImplementedError``.
* ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
the batched arugments, calling ``callback`` once for each batch element.
the batched arguments, calling ``callback`` once for each batch element.
* ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
added as the leading dimension unbatched inputs.
* ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the