Add print
First version with custom_partitioning. The communication during the gradient aren't optimal.
Fix the gradient sharding
small update
Fix the strange replicated computation.
Make it work with the new JAX version.
Add the structure for custom_p domentation.
Small clean up
First version of the doc
Add comment and typing annotation
tab->space
Simplify code and add docstring
Use the simpler JAX API since 0.4.16 (August 2023).
Custom partitioning using custom_partitioning
updated docs; dump custom_partitioning HLO
doc update
more documentation updates; include links to code instead of inlined code
fix typos
fix more typos
fix type annotations in source and update docs
minor fixes
import fix
lint fix
added apache license header