Merge pull request #19472 from mattjj:shmap-tutorial-typos

PiperOrigin-RevId: 600619148
jax authors 2024-01-22 17:40:15 -08:00
commit f21022b13b
2 changed files with 4 additions and 6 deletions

View File

@@ -775,7 +775,7 @@
 "In deep learning, we might use `all_gather`s on parameters in fully sharded\n",
 "data parallelism (FSDP).\n",
 "\n",
-"# psum_scatter\n",
+"## `psum_scatter`\n",
 "\n",
 "The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like\n",
 "`psum` except each function instance gets only one shard of the result:\n",
@@ -1535,8 +1535,7 @@
 "\n",
 "# adapt the loss function to sum the losses across devices\n",
 "def loss_dp(params, batch):\n",
-"  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P(),\n",
-"           check_rep=False) # TODO remove check_rep=False\n",
+"  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())\n",
 "  def loss_spmd(local_batch):\n",
 "    inputs, targets = local_batch\n",
 "    predictions = predict(params, inputs) # use reference `predict`\n",

View File

@@ -562,7 +562,7 @@ def all_gather_ref(_, x_blocks, *, tiled=False):
 In deep learning, we might use `all_gather`s on parameters in fully sharded
 data parallelism (FSDP).
 
-# psum_scatter
+## `psum_scatter`
 
 The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like
 `psum` except each function instance gets only one shard of the result:
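
For readers skimming this diff, here is a minimal sketch (not part of this change; the mesh size and array shapes are assumptions for illustration) of what `psum_scatter` does under `shard_map`: each instance contributes a full-length array and gets back only its own shard of the elementwise sum.

```python
# Minimal sketch, not from the tutorial: psum_scatter under shard_map.
# Assumes 4 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4.
import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:4]), ('i',))

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i'))
def sum_then_scatter(x_block):
  # Each of the 4 instances holds a (1, 8) block; summing over axis 0 gives a
  # full-length (8,) vector per instance.
  contrib = x_block.sum(axis=0)
  # psum_scatter adds the four (8,) vectors across 'i', but hands each
  # instance only its own length-2 shard of the total (tiled=True keeps rank).
  return jax.lax.psum_scatter(contrib, 'i', tiled=True)

x = jnp.arange(32.).reshape(4, 8)
print(sum_then_scatter(x).shape)  # (8,): four length-2 shards concatenated
```
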
@@ -1081,8 +1081,7 @@ params = jax.device_put(params, NamedSharding(mesh, P()))
 
 # adapt the loss function to sum the losses across devices
 def loss_dp(params, batch):
-  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P(),
-           check_rep=False) # TODO remove check_rep=False
+  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
   def loss_spmd(local_batch):
     inputs, targets = local_batch
     predictions = predict(params, inputs) # use reference `predict`
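
For context, a hedged sketch (not the tutorial's exact code) of how `loss_spmd` can finish: a toy `predict` stands in for the reference one, and averaging the per-device losses with `jax.lax.pmean` over 'batch' makes the result replicated, so `out_specs=P()` now passes the replication check without `check_rep=False`.

```python
# Hedged sketch, not the tutorial's exact code: the data-parallel loss pattern
# above, with a toy `predict` stand-in.
import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('batch',))

def predict(params, inputs):  # toy stand-in for the tutorial's reference predict
  return inputs @ params

def loss_dp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def loss_spmd(local_batch):
    inputs, targets = local_batch
    predictions = predict(params, inputs)  # forward pass on this device's shard
    local_loss = jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))
    # Average across the 'batch' axis so every instance returns the same value,
    # i.e. a replicated scalar, which is what out_specs=P() promises.
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(batch)

n = jax.device_count()
params = jnp.ones((3, 3))
batch = (jnp.ones((8 * n, 3)), jnp.zeros((8 * n, 3)))
print(loss_dp(params, batch))
```
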