# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Serialization and deserialization of _export.Exported

from __future__ import annotations

import types
from collections.abc import Callable, Sequence
from functools import partial
from typing import TypeVar

try:
  import flatbuffers
except ImportError as e:
  raise ImportError(
      "Please install 'flatbuffers' in order to use Exported serialization"
  ) from e

from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src import prng
from jax._src import tree_util
from jax._src.export import serialization_generated as ser_flatbuf
from jax._src.export import _export
from jax._src.export import shape_poly
from jax._src.lib import xla_client

import numpy as np

T = TypeVar("T")
SerT = TypeVar("SerT")

# The _SERIALIZATION_VERSION changes when we change the serialization schema
# even if the change is backwards compatible.
# Version 1, Nov 2023, first version.
# Version 2, Dec 16th, 2023, adds the f0 dtype.
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
#   This version is backwards compatible with Version 2.
# Version 4, April 7th, 2025, adds serialization for PRNGs key types.
#   This version is backwards compatible with Version 2 and 3.
# The writer below emits custom pytree nodes and PRNG key dtypes, so the
# version it records must be the latest one (4); recording an older version
# would let older readers accept artifacts they cannot actually decode.
_SERIALIZATION_VERSION = 4

# Serialization versions accepted by `deserialize`. This must always contain
# _SERIALIZATION_VERSION so that anything we write can be read back.
_SUPPORTED_VERSIONS = (2, 3, 4)


def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
  """Serializes an Exported.

  Args:
    exp: the Exported to serialize.
    vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
      serialize the primal functions and two orders of the `vjp` function. This
      should allow 2nd order reverse mode differentiation of the deserialized
      function. i.e., `jax.grad(jax.grad(f))`.

  Returns:
    The flatbuffer bytes produced by `flatbuffers.Builder.Output()`.

  Raises:
    ValueError: if `vjp_order` asks for more vjp levels than `exp` has.
  """
  builder = flatbuffers.Builder(65536)
  exported = _serialize_exported(builder, exp, vjp_order)
  builder.Finish(exported)
  return builder.Output()


def deserialize(ser: bytearray) -> _export.Exported:
  """Deserializes an Exported.

  Args:
    ser: bytes previously produced by `serialize`.

  Returns:
    The reconstructed `_export.Exported`.

  Raises:
    NotImplementedError: if the serialization version of `ser` is not one of
      the supported versions.
  """
  exp = ser_flatbuf.Exported.GetRootAsExported(ser)
  return _deserialize_exported(exp)


def _serialize_exported(
    builder: flatbuffers.Builder, exp: _export.Exported, vjp_order: int
) -> int:
  """Writes `exp` (and recursively `vjp_order` levels of vjps) to `builder`.

  Returns the flatbuffer offset of the written `Exported` table.
  """
  # Serialize bottom-up: flatbuffers requires every nested object (strings,
  # vectors, sub-tables) to be written before the enclosing table is started.
  fun_name = builder.CreateString(exp.fun_name)
  in_tree = _serialize_pytreedef(builder, exp.in_tree)
  in_avals = _serialize_array(builder, _serialize_aval, exp.in_avals)
  out_tree = _serialize_pytreedef(builder, exp.out_tree)
  out_avals = _serialize_array(builder, _serialize_aval, exp.out_avals)
  in_shardings = _serialize_array(
      builder, _serialize_sharding, exp.in_shardings_hlo
  )
  out_shardings = _serialize_array(
      builder, _serialize_sharding, exp.out_shardings_hlo
  )
  ordered_effects = _serialize_array(
      builder, _serialize_effect, exp.ordered_effects
  )
  unordered_effects = _serialize_array(
      builder, _serialize_effect, exp.unordered_effects
  )
  disabled_safety_checks = _serialize_array(
      builder, _serialize_disabled_safety_check, exp.disabled_safety_checks
  )
  platforms = _serialize_array(
      builder, lambda b, p: b.CreateString(p), exp.platforms
  )
  mlir_module_serialized = builder.CreateByteVector(exp.mlir_module_serialized)
  module_kept_var_idx = builder.CreateNumpyVector(
      np.array(exp.module_kept_var_idx, dtype=np.uint16)
  )

  vjp = None
  if vjp_order > 0:
    if not exp.has_vjp():
      # TODO: add test
      raise ValueError(
          "serialization of an Exported that does not have vjps of high-enough "
          "order"
      )
    vjp = _serialize_exported(builder, exp.vjp(), vjp_order - 1)

  ser_flatbuf.ExportedStart(builder)
  ser_flatbuf.ExportedAddSerializationVersion(builder, _SERIALIZATION_VERSION)
  ser_flatbuf.ExportedAddFunctionName(builder, fun_name)
  ser_flatbuf.ExportedAddInTree(builder, in_tree)
  ser_flatbuf.ExportedAddInAvals(builder, in_avals)
  ser_flatbuf.ExportedAddOutTree(builder, out_tree)
  ser_flatbuf.ExportedAddOutAvals(builder, out_avals)
  ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices)
  ser_flatbuf.ExportedAddInShardings(builder, in_shardings)
  ser_flatbuf.ExportedAddOutShardings(builder, out_shardings)
  ser_flatbuf.ExportedAddPlatforms(builder, platforms)
  ser_flatbuf.ExportedAddOrderedEffects(builder, ordered_effects)
  ser_flatbuf.ExportedAddUnorderedEffects(builder, unordered_effects)
  ser_flatbuf.ExportedAddDisabledChecks(builder, disabled_safety_checks)
  ser_flatbuf.ExportedAddMlirModuleSerialized(builder, mlir_module_serialized)
  ser_flatbuf.ExportedAddCallingConventionVersion(
      builder, exp.calling_convention_version
  )
  ser_flatbuf.ExportedAddModuleKeptVarIdx(builder, module_kept_var_idx)
  ser_flatbuf.ExportedAddUsesGlobalConstants(
      builder, exp.uses_global_constants
  )
  if vjp is not None:
    ser_flatbuf.ExportedAddVjp(builder, vjp)
  return ser_flatbuf.ExportedEnd(builder)


def _serialize_array(
    builder: flatbuffers.Builder,
    serialize_one: Callable[[flatbuffers.Builder, T], int],
    elements: Sequence[T],
) -> int:
  """Serializes `elements` as a flatbuffer vector of table/string offsets.

  `serialize_one(builder, e)` must write one element and return its offset.
  Returns the offset of the resulting vector.
  """
  element_offsets = [serialize_one(builder, e) for e in elements]
  # Any "StartXxxVector" for a vector of offsets has the same layout; the
  # PyTreeDef children one is reused here as a generic offset-vector start.
  ser_flatbuf.PyTreeDefStartChildrenVector(builder, len(element_offsets))
  # Flatbuffers build vectors back-to-front, hence the reversed order.
  for sc in reversed(element_offsets):
    builder.PrependUOffsetTRelative(sc)
  return builder.EndVector()


def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
  """Reconstructs an `_export.Exported` from its flatbuffer table.

  Raises:
    NotImplementedError: if the serialization version is not supported.
  """
  serialization_version = exp.SerializationVersion()
  if serialization_version not in _SUPPORTED_VERSIONS:
    raise NotImplementedError(
        f"deserialize unsupported version {serialization_version}"
    )

  fun_name = exp.FunctionName().decode("utf-8")
  in_tree = tree_util.tree_structure(
      _deserialize_pytreedef_to_pytree(exp.InTree())
  )
  scope = shape_poly.SymbolicScope(())  # TODO: serialize the constraints
  deser_aval = partial(_deserialize_aval, scope=scope)
  in_avals = _deserialize_tuple(
      exp.InAvalsLength, exp.InAvals, deser_aval
  )
  out_tree = tree_util.tree_structure(
      _deserialize_pytreedef_to_pytree(exp.OutTree())
  )
  out_avals = _deserialize_tuple(
      exp.OutAvalsLength, exp.OutAvals, deser_aval
  )
  nr_devices = exp.NrDevices()
  in_shardings = _deserialize_tuple(
      exp.InShardingsLength, exp.InShardings, _deserialize_sharding
  )
  out_shardings = _deserialize_tuple(
      exp.OutShardingsLength, exp.OutShardings, _deserialize_sharding
  )
  platforms = _deserialize_tuple(
      exp.PlatformsLength,
      exp.Platforms,
      lambda v: v.decode("utf-8"),
  )
  ordered_effects = _deserialize_tuple(
      exp.OrderedEffectsLength, exp.OrderedEffects, _deserialize_effect
  )
  unordered_effects = _deserialize_tuple(
      exp.UnorderedEffectsLength, exp.UnorderedEffects, _deserialize_effect
  )
  disabled_safety_checks = _deserialize_tuple(
      exp.DisabledChecksLength,
      exp.DisabledChecks,
      _deserialize_disabled_safety_check,
  )

  mlir_module_serialized = exp.MlirModuleSerializedAsNumpy().tobytes()
  calling_convention_version = exp.CallingConventionVersion()
  module_kept_var_idx = tuple(exp.ModuleKeptVarIdxAsNumpy().tolist())
  uses_global_constants = exp.UsesGlobalConstants()

  # The vjp, if present, is deserialized lazily, only when requested.
  _get_vjp = None
  if vjp := exp.Vjp():
    _get_vjp = lambda _: _deserialize_exported(vjp)

  return _export.Exported(
      fun_name=fun_name,
      in_tree=in_tree,
      in_avals=in_avals,
      out_tree=out_tree,
      out_avals=out_avals,
      nr_devices=nr_devices,
      in_shardings_hlo=in_shardings,
      out_shardings_hlo=out_shardings,
      platforms=platforms,
      ordered_effects=ordered_effects,
      unordered_effects=unordered_effects,
      disabled_safety_checks=disabled_safety_checks,
      mlir_module_serialized=mlir_module_serialized,
      calling_convention_version=calling_convention_version,
      module_kept_var_idx=module_kept_var_idx,
      uses_global_constants=uses_global_constants,
      _get_vjp=_get_vjp,
  )


def _deserialize_tuple(
    get_len: Callable[[], int],
    get_elem: Callable[[int], SerT],
    deserialize_one: Callable[[SerT], T],
) -> tuple[T, ...]:
  """Maps `deserialize_one` over the `get_len()` elements of a vector."""
  return tuple(deserialize_one(get_elem(i)) for i in range(get_len()))


def _serialize_pytreedef(
    builder: flatbuffers.Builder, p: tree_util.PyTreeDef
) -> int:
  """Recursively serializes a PyTreeDef and returns its table offset.

  Raises:
    ValueError: if the tree contains a node type that is neither a builtin
      container nor registered in `_export.serialization_registry`, or if a
      custom auxdata serializer does not return bytes.
  """
  node_data = p.node_data()
  children = p.children()

  children_vector_offset = None
  children_names_vector_offset = None
  if children:
    children_vector_offset = _serialize_array(
        builder, _serialize_pytreedef, children
    )
  custom_name = None
  custom_auxdata = None
  # `node_data` is None for leaves; otherwise a (node_type, auxdata) pair.
  node_type = node_data and node_data[0]

  if node_data is None:  # leaf
    kind = ser_flatbuf.PyTreeDefKind.leaf
  elif node_type is types.NoneType:
    kind = ser_flatbuf.PyTreeDefKind.none
  elif node_type is tuple:
    kind = ser_flatbuf.PyTreeDefKind.tuple
  elif node_type is list:
    kind = ser_flatbuf.PyTreeDefKind.list
  elif node_type is dict:
    kind = ser_flatbuf.PyTreeDefKind.dict
    # For dicts the auxdata holds the keys, one per child.
    assert len(node_data[1]) == len(children)
    children_names_vector_offset = _serialize_array(
        builder, lambda b, s: b.CreateString(s), node_data[1]
    )
  elif node_type in _export.serialization_registry:
    kind = ser_flatbuf.PyTreeDefKind.custom
    serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
    custom_name = builder.CreateString(serialized_name)
    serialized_auxdata = serialize_auxdata(node_data[1])
    if not isinstance(serialized_auxdata, (bytes, bytearray)):
      raise ValueError(
          "The custom serialization function for `node_type` must "
          f"return a `bytes` object. It returned a {type(serialized_auxdata)}.")
    custom_auxdata = builder.CreateByteVector(serialized_auxdata)
  else:
    raise ValueError(
        "Cannot serialize PyTreeDef containing an "
        f"unregistered type `{node_type}`. "
        "Use `export.register_pytree_node_serialization` or "
        "`export.register_namedtuple_serialization`.")

  ser_flatbuf.PyTreeDefStart(builder)
  ser_flatbuf.PyTreeDefAddKind(builder, kind)
  if children_vector_offset is not None:
    ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset)
  if children_names_vector_offset is not None:
    ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset)
  if custom_name is not None:
    ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name)
  if custom_auxdata is not None:
    ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata)
  return ser_flatbuf.PyTreeDefEnd(builder)


def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
  """Rebuilds a sample pytree (leaves are `0.0`) with the serialized structure.

  Raises:
    ValueError: if a custom node's name is not in
      `_export.deserialization_registry`.
  """
  # We construct a PyTree and later we'll flatten it to get the PyTreeDef.
  # TODO: is there a more direct way to construct a PyTreeDef?
  kind = p.Kind()
  nr_children = p.ChildrenLength()
  children = [
      _deserialize_pytreedef_to_pytree(p.Children(i))
      for i in range(nr_children)
  ]
  if kind == ser_flatbuf.PyTreeDefKind.leaf:
    return 0.0
  elif kind == ser_flatbuf.PyTreeDefKind.none:
    return None
  elif kind == ser_flatbuf.PyTreeDefKind.tuple:
    return tuple(children)
  elif kind == ser_flatbuf.PyTreeDefKind.list:
    return list(children)
  elif kind == ser_flatbuf.PyTreeDefKind.dict:
    assert p.ChildrenNamesLength() == nr_children
    keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)]
    return dict(zip(keys, children))
  elif kind == ser_flatbuf.PyTreeDefKind.custom:
    serialized_name = p.CustomName().decode("utf-8")
    if serialized_name not in _export.deserialization_registry:
      raise ValueError(
          "Cannot deserialize a PyTreeDef containing an "
          f"unregistered type `{serialized_name}`. "
          "Use `export.register_pytree_node_serialization` or "
          "`export.register_namedtuple_serialization`.")
    # The registered node type itself is not needed to rebuild the node.
    _, deserialize_auxdata, from_iter = _export.deserialization_registry[serialized_name]
    auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes())
    return from_iter(auxdata, children)
  else:
    assert False, kind


_dtype_to_dtype_kind = {
    np.dtype("bool"): ser_flatbuf.DType.bool,
    np.dtype("int8"): ser_flatbuf.DType.i8,
    np.dtype("int16"): ser_flatbuf.DType.i16,
    np.dtype("int32"): ser_flatbuf.DType.i32,
    np.dtype("int64"): ser_flatbuf.DType.i64,
    np.dtype("uint8"): ser_flatbuf.DType.ui8,
    np.dtype("uint16"): ser_flatbuf.DType.ui16,
    np.dtype("uint32"): ser_flatbuf.DType.ui32,
    np.dtype("uint64"): ser_flatbuf.DType.ui64,
    dtypes.float0: ser_flatbuf.DType.f0,
    np.dtype("float16"): ser_flatbuf.DType.f16,
    np.dtype("float32"): ser_flatbuf.DType.f32,
    np.dtype("float64"): ser_flatbuf.DType.f64,
    np.dtype("complex64"): ser_flatbuf.DType.c64,
    np.dtype("complex128"): ser_flatbuf.DType.c128,
    dtypes._bfloat16_dtype: ser_flatbuf.DType.bf16,
    dtypes._int4_dtype: ser_flatbuf.DType.i4,
    dtypes._uint4_dtype: ser_flatbuf.DType.ui4,
    dtypes._float8_e4m3b11fnuz_dtype: ser_flatbuf.DType.f8_e4m3b11fnuz,
    dtypes._float8_e4m3fn_dtype: ser_flatbuf.DType.f8_e4m3fn,
    dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz,
    dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2,
    dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
    dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4,
    dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3,
    dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu,
    dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn,
    prng.KeyTy(prng.prngs["threefry2x32"]): ser_flatbuf.DType.key_fry,
    prng.KeyTy(prng.prngs["rbg"]): ser_flatbuf.DType.key_rbg,
    prng.KeyTy(prng.prngs["unsafe_rbg"]): ser_flatbuf.DType.key_unsafe_rbg,
}

# Inverse of _dtype_to_dtype_kind, used for deserialization.
_dtype_kind_to_dtype = {
    kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}


def _serialize_aval(
    builder: flatbuffers.Builder, aval: core.ShapedArray
) -> int:
  """Serializes a ShapedArray as its dtype kind plus one string per dim.

  Dimensions are stored via `str(d)` so that symbolic dimensions can be
  re-parsed by `shape_poly.symbolic_shape` on deserialization.
  """
  aval_kind = ser_flatbuf.AbstractValueKind.shapedArray
  shape_offsets = [builder.CreateString(str(d)) for d in aval.shape]
  ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape))
  for d in reversed(shape_offsets):
    builder.PrependUOffsetTRelative(d)
  shape_vector_offset = builder.EndVector()

  ser_flatbuf.AbstractValueStart(builder)
  ser_flatbuf.AbstractValueAddKind(builder, aval_kind)
  ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset)
  ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype])
  return ser_flatbuf.AbstractValueEnd(builder)


def _deserialize_aval(aval: ser_flatbuf.AbstractValue, scope) -> core.ShapedArray:
  """Rebuilds a ShapedArray, parsing dimensions in the symbolic `scope`."""
  aval_kind = aval.Kind()
  if aval_kind == ser_flatbuf.AbstractValueKind.shapedArray:
    dtype = _dtype_kind_to_dtype[aval.Dtype()]
    shape = shape_poly.symbolic_shape(
        ",".join(
            aval.Shape(i).decode("utf-8") for i in range(aval.ShapeLength())
        ),
        scope=scope
    )
    return core.ShapedArray(shape, dtype)
  else:
    assert False, aval_kind


def _serialize_sharding(
    builder: flatbuffers.Builder, s: _export.HloSharding | None
) -> int:
  """Serializes an HloSharding as its OpSharding proto bytes.

  `None` is written as the `unspecified` kind with no proto.
  """
  proto = None
  if s is None:
    kind = ser_flatbuf.ShardingKind.unspecified
  else:
    kind = ser_flatbuf.ShardingKind.hlo_sharding
    proto_bytes = s.to_proto().SerializeToString()
    proto = builder.CreateByteVector(proto_bytes)

  ser_flatbuf.ShardingStart(builder)
  ser_flatbuf.ShardingAddKind(builder, kind)
  if proto is not None:
    ser_flatbuf.ShardingAddHloShardingProto(builder, proto)
  return ser_flatbuf.ShardingEnd(builder)


def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.HloSharding | None:
  """Inverse of `_serialize_sharding`."""
  kind = s.Kind()
  if kind == ser_flatbuf.ShardingKind.unspecified:
    return None

  if kind == ser_flatbuf.ShardingKind.hlo_sharding:
    proto_str = s.HloShardingProtoAsNumpy().tobytes()
    proto = xla_client.OpSharding()
    proto.ParseFromString(proto_str)

    return xla_client.HloSharding.from_proto(proto)

  assert False, kind


def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int:
  """Serializes an effect by recording the `str` of its class.

  Deserialization rebuilds the effect by calling the class's nullary
  constructor, so only effects for which that constructor yields an equal,
  equally-hashed object are accepted.

  Raises:
    NotImplementedError: if the effect does not meet those requirements.
  """
  try:
    eff_replica = eff.__class__()
  except Exception as e:
    raise NotImplementedError(
        f"Effect {eff} must have a nullary constructor to be serializable"
    ) from e
  try:
    hash_eff = hash(eff)
    hash_eff_replica = hash(eff_replica)
  except Exception as e:
    raise NotImplementedError(
        f"Effect {eff} must be hashable to be serializable"
    ) from e
  if eff != eff_replica or hash_eff != hash_eff_replica:
    raise NotImplementedError(
        f"Effect {eff} must have a nullary class constructor that produces an "
        "equal effect object."
    )
  effect_type_name = str(eff.__class__)
  effect_type_name_offset = builder.CreateString(effect_type_name)
  ser_flatbuf.EffectStart(builder)
  ser_flatbuf.EffectAddTypeName(builder, effect_type_name_offset)
  # Close the Effect table with its own End function (matches EffectStart).
  return ser_flatbuf.EffectEnd(builder)


def _deserialize_effect(eff: ser_flatbuf.Effect) -> core.Effect:
  """Finds the lowerable effect type whose `str` matches and instantiates it.

  Raises:
    NotImplementedError: if no registered lowerable effect type matches, or
      the matching type cannot be constructed without arguments.
  """
  effect_type_name = eff.TypeName().decode("utf-8")
  for existing_effect_type in effects.lowerable_effects._effect_types:
    if str(existing_effect_type) == effect_type_name:
      try:
        return existing_effect_type()
      except Exception as e:
        # TODO: add test
        raise NotImplementedError(
            f"deserializing effect {effect_type_name} that does not have a "
            "nullary class constructor"
        ) from e

  raise NotImplementedError(
      f"cannot deserialize effect type {effect_type_name}"
  )


def _serialize_disabled_safety_check(
    builder: flatbuffers.Builder, check: _export.DisabledSafetyCheck
) -> int:
  """Serializes a DisabledSafetyCheck (custom_call target or platform).

  Raises:
    NotImplementedError: for any other kind of check.
  """
  custom_call_target_str = check.is_custom_call()
  custom_call_target = None
  if custom_call_target_str is not None:
    kind = ser_flatbuf.DisabledSafetyCheckKind.custom_call
    custom_call_target = builder.CreateString(custom_call_target_str)
  elif check == _export.DisabledSafetyCheck.platform():
    kind = ser_flatbuf.DisabledSafetyCheckKind.platform
  else:
    raise NotImplementedError(f"serializing DisabledSafetyCheck: {check}")

  ser_flatbuf.DisabledSafetyCheckStart(builder)
  ser_flatbuf.DisabledSafetyCheckAddKind(builder, kind)
  if custom_call_target is not None:
    ser_flatbuf.DisabledSafetyCheckAddCustomCallTarget(
        builder, custom_call_target
    )
  return ser_flatbuf.DisabledSafetyCheckEnd(builder)


def _deserialize_disabled_safety_check(
    sc: ser_flatbuf.DisabledSafetyCheck,
) -> _export.DisabledSafetyCheck:
  """Inverse of `_serialize_disabled_safety_check`.

  Also accepts the legacy `shape_assertions` kind; see the comment below.
  """
  kind = sc.Kind()
  if kind == ser_flatbuf.DisabledSafetyCheckKind.custom_call:
    return _export.DisabledSafetyCheck.custom_call(
        sc.CustomCallTarget().decode("utf-8")
    )
  if kind == ser_flatbuf.DisabledSafetyCheckKind.platform:
    return _export.DisabledSafetyCheck.platform()
  if kind == ser_flatbuf.DisabledSafetyCheckKind.shape_assertions:
    # shape_assertions has been deprecated in June 2024 (turned into a no-op),
    # and removed in November 2024. We deserialize it to a DisabledSafetyCheck
    # that has no effect.
    # TODO(necula): remove this after June 2025, when we should not have any
    # more serialized artifacts with shape_assertions.
    return _export.DisabledSafetyCheck.custom_call("no op")
  assert False, kind