6 Commits

Author SHA1 Message Date
Parker Schuh
261ff9e9ed Stop passing CompileOptions when deserializing.
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
Parker Schuh
484eb26d2a Redefine compile_and_serialize as serialize(lowered.compile()).
This has the downside of keeping around the UnloadedMeshComputation,
but it makes the serialize() API easier to understand.

PiperOrigin-RevId: 518715469
2023-03-22 17:23:19 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
f445c84ba4 Add support for a list of allow_spmd_sharding_propagation_to_output. This gives us more flexibility to tell SPMD which shardings to override.
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
Parker Schuh
91634e0da4 Refactor create_cpp_call to be a method on MeshExecutable
rather then being passed all the way down from pjit.py.

PiperOrigin-RevId: 489353681
2022-11-17 18:05:10 -08:00
Parker Schuh
da765a2e54 Allow compiling and then serializing jax.stages.Lowered.
This adds experimental APIs to `serialize_executable.py`:

`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`

for serializing and deserializing executables.

PiperOrigin-RevId: 489014705
2022-11-16 12:54:10 -08:00