mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24886 from carlosgmartin:fix_typos
PiperOrigin-RevId: 696272953
This commit is contained in:
commit
426e13a5aa
@ -679,7 +679,7 @@ class RmsNormFwdClass:
|
||||
NamedSharding(mesh, PartitionSpec(None, None)))
|
||||
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
|
||||
output_shardings = (arg_shardings[0], invvar_sharding)
|
||||
# Sharded_impl only accepts positional arugments
|
||||
# Sharded_impl only accepts positional arguments
|
||||
# And they should be Jax traceable variables
|
||||
impl = partial(RmsNormFwdClass.impl, eps=eps)
|
||||
|
||||
@ -739,7 +739,7 @@ class RmsNormBwdClass:
|
||||
output_shardings = (output_sharding, invvar_sharding, invvar_sharding)
|
||||
|
||||
|
||||
# Sharded_impl only accepts positional arugments
|
||||
# Sharded_impl only accepts positional arguments
|
||||
# And they should be Jax traceable variables
|
||||
def impl(g, invvar, x, weight):
|
||||
grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
|
||||
|
@ -353,7 +353,7 @@ class RmsNormFwdClass:
|
||||
NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything.
|
||||
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
|
||||
output_shardings = (arg_shardings[0], invvar_sharding)
|
||||
# Sharded_impl only accepts positional arugments
|
||||
# Sharded_impl only accepts positional arguments
|
||||
# And they should be Jax traceable variables
|
||||
impl = partial(RmsNormFwdClass.impl, eps=eps)
|
||||
|
||||
|
@ -343,7 +343,7 @@ def pure_callback(
|
||||
* Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method``
|
||||
is deprecated and it will eventually raise ``NotImplementedError``.
|
||||
* ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
|
||||
the batched arugments, calling ``callback`` once for each batch element.
|
||||
the batched arguments, calling ``callback`` once for each batch element.
|
||||
* ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
|
||||
added as the leading dimension unbatched inputs.
|
||||
* ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the
|
||||
|
Loading…
x
Reference in New Issue
Block a user