mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #19472 from mattjj:shmap-tutorial-typos
PiperOrigin-RevId: 600619148
This commit is contained in:
commit
f21022b13b
@ -775,7 +775,7 @@
|
||||
"In deep learning, we might use `all_gather`s on parameters in fully sharded\n",
|
||||
"data parallelism (FSDP).\n",
|
||||
"\n",
|
||||
"# psum_scatter\n",
|
||||
"## `psum_scatter`\n",
|
||||
"\n",
|
||||
"The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like\n",
|
||||
"`psum` except each function instance gets only one shard of the result:\n",
|
||||
@ -1535,8 +1535,7 @@
|
||||
"\n",
|
||||
"# adapt the loss function to sum the losses across devices\n",
|
||||
"def loss_dp(params, batch):\n",
|
||||
" @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P(),\n",
|
||||
" check_rep=False) # TODO remove check_rep=False\n",
|
||||
" @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())\n",
|
||||
" def loss_spmd(local_batch):\n",
|
||||
" inputs, targets = local_batch\n",
|
||||
" predictions = predict(params, inputs) # use reference `predict`\n",
|
||||
|
@ -562,7 +562,7 @@ def all_gather_ref(_, x_blocks, *, tiled=False):
|
||||
In deep learning, we might use `all_gather`s on parameters in fully sharded
|
||||
data parallelism (FSDP).
|
||||
|
||||
# psum_scatter
|
||||
## `psum_scatter`
|
||||
|
||||
The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like
|
||||
`psum` except each function instance gets only one shard of the result:
|
||||
@ -1081,8 +1081,7 @@ params = jax.device_put(params, NamedSharding(mesh, P()))
|
||||
|
||||
# adapt the loss function to sum the losses across devices
|
||||
def loss_dp(params, batch):
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P(),
|
||||
check_rep=False) # TODO remove check_rep=False
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
|
||||
def loss_spmd(local_batch):
|
||||
inputs, targets = local_batch
|
||||
predictions = predict(params, inputs) # use reference `predict`
|
||||
|
Loading…
x
Reference in New Issue
Block a user