mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 23:06:05 +00:00
522 lines
19 KiB
Python
522 lines
19 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Serialization and deserialization of _export.Exported
|
|
|
|
from __future__ import annotations
|
|
|
|
import types
|
|
from collections.abc import Callable, Sequence
|
|
from functools import partial
|
|
from typing import TypeVar
|
|
|
|
try:
|
|
import flatbuffers
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"Please install 'flatbuffers' in order to use Exported serialization"
|
|
) from e
|
|
|
|
from jax._src import core
|
|
from jax._src import dtypes
|
|
from jax._src import effects
|
|
from jax._src import tree_util
|
|
from jax._src.export import serialization_generated as ser_flatbuf
|
|
from jax._src.export import _export
|
|
from jax._src.export import shape_poly
|
|
from jax._src.lib import xla_client
|
|
|
|
import numpy as np
|
|
|
|
T = TypeVar("T")
|
|
SerT = TypeVar("SerT")
|
|
|
|
# The _SERIALIZATION_VERSION changes when we change the serialization schema
|
|
# even if the change is backwards compatible.
|
|
# Version 1, Nov 2023, first version.
|
|
# Version 2, Dec 16th, 2023, adds the f0 dtype.
|
|
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
|
|
# This version is backwards compatible with Version 2.
|
|
_SERIALIZATION_VERSION = 2
|
|
|
|
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
|
|
"""Serializes an Exported.
|
|
|
|
Args:
|
|
exp: the Exported to serialize.
|
|
vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
|
|
serialize the primal functions and two orders of the `vjp` function. This
|
|
should allow 2nd order reverse mode differentiation of the deserialized
|
|
function. i.e., `jax.grad(jax.grad(f)).`
|
|
"""
|
|
builder = flatbuffers.Builder(65536)
|
|
exported = _serialize_exported(builder, exp, vjp_order)
|
|
builder.Finish(exported)
|
|
return builder.Output()
|
|
|
|
|
|
def deserialize(ser: bytearray) -> _export.Exported:
|
|
"""Deserializes an Exported."""
|
|
exp = ser_flatbuf.Exported.GetRootAsExported(ser)
|
|
return _deserialize_exported(exp)
|
|
|
|
|
|
def _serialize_exported(
|
|
builder: flatbuffers.Builder, exp: _export.Exported, vjp_order: int
|
|
) -> int:
|
|
# Serialize bottom-up
|
|
fun_name = builder.CreateString(exp.fun_name)
|
|
in_tree = _serialize_pytreedef(builder, exp.in_tree)
|
|
in_avals = _serialize_array(builder, _serialize_aval, exp.in_avals)
|
|
out_tree = _serialize_pytreedef(builder, exp.out_tree)
|
|
out_avals = _serialize_array(builder, _serialize_aval, exp.out_avals)
|
|
in_shardings = _serialize_array(
|
|
builder, _serialize_sharding, exp.in_shardings_hlo
|
|
)
|
|
out_shardings = _serialize_array(
|
|
builder, _serialize_sharding, exp.out_shardings_hlo
|
|
)
|
|
ordered_effects = _serialize_array(
|
|
builder, _serialize_effect, exp.ordered_effects
|
|
)
|
|
unordered_effects = _serialize_array(
|
|
builder, _serialize_effect, exp.unordered_effects
|
|
)
|
|
disabled_safety_checks = _serialize_array(
|
|
builder, _serialize_disabled_safety_check, exp.disabled_safety_checks
|
|
)
|
|
platforms = _serialize_array(
|
|
builder, lambda b, p: b.CreateString(p), exp.platforms
|
|
)
|
|
mlir_module_serialized = builder.CreateByteVector(exp.mlir_module_serialized)
|
|
module_kept_var_idx = builder.CreateNumpyVector(
|
|
np.array(exp.module_kept_var_idx, dtype=np.uint16)
|
|
)
|
|
|
|
vjp = None
|
|
if vjp_order > 0:
|
|
if not exp.has_vjp():
|
|
# TODO: add test
|
|
raise ValueError(
|
|
"serialization of an Exported that does not have vjps of high-enough "
|
|
"order"
|
|
)
|
|
vjp = _serialize_exported(builder, exp.vjp(), vjp_order - 1)
|
|
|
|
ser_flatbuf.ExportedStart(builder)
|
|
ser_flatbuf.ExportedAddSerializationVersion(builder, _SERIALIZATION_VERSION)
|
|
ser_flatbuf.ExportedAddFunctionName(builder, fun_name)
|
|
ser_flatbuf.ExportedAddInTree(builder, in_tree)
|
|
ser_flatbuf.ExportedAddInAvals(builder, in_avals)
|
|
ser_flatbuf.ExportedAddOutTree(builder, out_tree)
|
|
ser_flatbuf.ExportedAddOutAvals(builder, out_avals)
|
|
ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices)
|
|
ser_flatbuf.ExportedAddInShardings(builder, in_shardings)
|
|
ser_flatbuf.ExportedAddOutShardings(builder, out_shardings)
|
|
ser_flatbuf.ExportedAddPlatforms(builder, platforms)
|
|
ser_flatbuf.ExportedAddOrderedEffects(builder, ordered_effects)
|
|
ser_flatbuf.ExportedAddUnorderedEffects(builder, unordered_effects)
|
|
ser_flatbuf.ExportedAddDisabledChecks(builder, disabled_safety_checks)
|
|
ser_flatbuf.ExportedAddMlirModuleSerialized(builder, mlir_module_serialized)
|
|
ser_flatbuf.ExportedAddCallingConventionVersion(
|
|
builder, exp.calling_convention_version
|
|
)
|
|
ser_flatbuf.ExportedAddModuleKeptVarIdx(builder, module_kept_var_idx)
|
|
ser_flatbuf.ExportedAddUsesGlobalConstants(
|
|
builder, exp.uses_global_constants
|
|
)
|
|
if vjp is not None:
|
|
ser_flatbuf.ExportedAddVjp(builder, vjp)
|
|
return ser_flatbuf.ExportedEnd(builder)
|
|
|
|
|
|
def _serialize_array(
|
|
builder: flatbuffers.Builder,
|
|
serialize_one: Callable[[flatbuffers.Builder, T], int],
|
|
elements: Sequence[T],
|
|
) -> int:
|
|
element_offsets = [serialize_one(builder, e) for e in elements]
|
|
ser_flatbuf.PyTreeDefStartChildrenVector(builder, len(element_offsets))
|
|
for sc in reversed(element_offsets):
|
|
builder.PrependUOffsetTRelative(sc)
|
|
return builder.EndVector()
|
|
|
|
|
|
def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
|
|
serialization_version = exp.SerializationVersion()
|
|
if serialization_version not in [2, 3]:
|
|
raise NotImplementedError(
|
|
f"deserialize unsupported version {serialization_version}"
|
|
)
|
|
|
|
fun_name = exp.FunctionName().decode("utf-8")
|
|
in_tree = tree_util.tree_structure(
|
|
_deserialize_pytreedef_to_pytree(exp.InTree())
|
|
)
|
|
scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints
|
|
deser_aval = partial(_deserialize_aval, scope=scope)
|
|
in_avals = _deserialize_tuple(
|
|
exp.InAvalsLength, exp.InAvals, deser_aval
|
|
)
|
|
out_tree = tree_util.tree_structure(
|
|
_deserialize_pytreedef_to_pytree(exp.OutTree())
|
|
)
|
|
out_avals = _deserialize_tuple(
|
|
exp.OutAvalsLength, exp.OutAvals, deser_aval
|
|
)
|
|
nr_devices = exp.NrDevices()
|
|
in_shardings = _deserialize_tuple(
|
|
exp.InShardingsLength, exp.InShardings, _deserialize_sharding
|
|
)
|
|
out_shardings = _deserialize_tuple(
|
|
exp.OutShardingsLength, exp.OutShardings, _deserialize_sharding
|
|
)
|
|
platforms = _deserialize_tuple(
|
|
exp.PlatformsLength,
|
|
exp.Platforms,
|
|
lambda v: v.decode("utf-8"),
|
|
)
|
|
ordered_effects = _deserialize_tuple(
|
|
exp.OrderedEffectsLength, exp.OrderedEffects, _deserialize_effect
|
|
)
|
|
unordered_effects = _deserialize_tuple(
|
|
exp.UnorderedEffectsLength, exp.UnorderedEffects, _deserialize_effect
|
|
)
|
|
disabled_safety_checks = _deserialize_tuple(
|
|
exp.DisabledChecksLength,
|
|
exp.DisabledChecks,
|
|
_deserialize_disabled_safety_check,
|
|
)
|
|
|
|
mlir_module_serialized = exp.MlirModuleSerializedAsNumpy().tobytes()
|
|
calling_convention_version = exp.CallingConventionVersion()
|
|
module_kept_var_idx = tuple(exp.ModuleKeptVarIdxAsNumpy().tolist())
|
|
uses_global_constants = exp.UsesGlobalConstants()
|
|
|
|
_get_vjp = None
|
|
if vjp := exp.Vjp():
|
|
_get_vjp = lambda _: _deserialize_exported(vjp)
|
|
|
|
return _export.Exported(
|
|
fun_name=fun_name,
|
|
in_tree=in_tree,
|
|
in_avals=in_avals,
|
|
out_tree=out_tree,
|
|
out_avals=out_avals,
|
|
nr_devices=nr_devices,
|
|
in_shardings_hlo=in_shardings,
|
|
out_shardings_hlo=out_shardings,
|
|
platforms=platforms,
|
|
ordered_effects=ordered_effects,
|
|
unordered_effects=unordered_effects,
|
|
disabled_safety_checks=disabled_safety_checks,
|
|
mlir_module_serialized=mlir_module_serialized,
|
|
calling_convention_version=calling_convention_version,
|
|
module_kept_var_idx=module_kept_var_idx,
|
|
uses_global_constants=uses_global_constants,
|
|
_get_vjp=_get_vjp,
|
|
)
|
|
|
|
|
|
def _deserialize_tuple(
|
|
get_len: Callable[[], int],
|
|
get_elem: Callable[[int], SerT],
|
|
deserialize_one: Callable[[SerT], T],
|
|
) -> tuple[T, ...]:
|
|
return tuple(deserialize_one(get_elem(i)) for i in range(get_len()))
|
|
|
|
|
|
def _serialize_pytreedef(
|
|
builder: flatbuffers.Builder, p: tree_util.PyTreeDef
|
|
) -> int:
|
|
node_data = p.node_data()
|
|
children = p.children()
|
|
|
|
children_vector_offset = None
|
|
children_names_vector_offset = None
|
|
if children:
|
|
children_vector_offset = _serialize_array(
|
|
builder, _serialize_pytreedef, children
|
|
)
|
|
custom_name = None
|
|
custom_auxdata = None
|
|
node_type = node_data and node_data[0]
|
|
|
|
if node_data is None: # leaf
|
|
kind = ser_flatbuf.PyTreeDefKind.leaf
|
|
elif node_type is types.NoneType:
|
|
kind = ser_flatbuf.PyTreeDefKind.none
|
|
elif node_type is tuple:
|
|
kind = ser_flatbuf.PyTreeDefKind.tuple
|
|
elif node_type is list:
|
|
kind = ser_flatbuf.PyTreeDefKind.list
|
|
elif node_type is dict:
|
|
kind = ser_flatbuf.PyTreeDefKind.dict
|
|
assert len(node_data[1]) == len(children)
|
|
children_names_vector_offset = _serialize_array(
|
|
builder, lambda b, s: b.CreateString(s), node_data[1]
|
|
)
|
|
elif node_type in _export.serialization_registry:
|
|
kind = ser_flatbuf.PyTreeDefKind.custom
|
|
serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
|
|
custom_name = builder.CreateString(serialized_name)
|
|
serialized_auxdata = serialize_auxdata(node_data[1])
|
|
if not isinstance(serialized_auxdata, (bytes, bytearray)):
|
|
raise ValueError(
|
|
"The custom serialization function for `node_type` must "
|
|
f"return a `bytes` object. It returned a {type(serialized_auxdata)}.")
|
|
custom_auxdata = builder.CreateByteVector(serialized_auxdata)
|
|
else:
|
|
raise ValueError(
|
|
"Cannot serialize PyTreeDef containing an "
|
|
f"unregistered type `{node_type}`. "
|
|
"Use `export.register_pytree_node_serialization` or "
|
|
"`export.register_namedtuple_serialization`.")
|
|
|
|
ser_flatbuf.PyTreeDefStart(builder)
|
|
ser_flatbuf.PyTreeDefAddKind(builder, kind)
|
|
if children_vector_offset:
|
|
ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset)
|
|
if children_names_vector_offset:
|
|
ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset)
|
|
if custom_name is not None:
|
|
ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name)
|
|
if custom_auxdata is not None:
|
|
ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata)
|
|
return ser_flatbuf.PyTreeDefEnd(builder)
|
|
|
|
|
|
def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
|
|
# We construct a PyTree and later we'll flatten it to get the PyTreeDef.
|
|
# TODO: is there a more direct way to construct a PyTreeDef?
|
|
kind = p.Kind()
|
|
nr_children = p.ChildrenLength()
|
|
children = [
|
|
_deserialize_pytreedef_to_pytree(p.Children(i))
|
|
for i in range(nr_children)
|
|
]
|
|
if kind == ser_flatbuf.PyTreeDefKind.leaf:
|
|
return 0.0
|
|
elif kind == ser_flatbuf.PyTreeDefKind.none:
|
|
return None
|
|
elif kind == ser_flatbuf.PyTreeDefKind.tuple:
|
|
return tuple(children)
|
|
elif kind == ser_flatbuf.PyTreeDefKind.list:
|
|
return list(children)
|
|
elif kind == ser_flatbuf.PyTreeDefKind.dict:
|
|
assert p.ChildrenNamesLength() == nr_children
|
|
keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)]
|
|
return dict(zip(keys, children))
|
|
elif kind == ser_flatbuf.PyTreeDefKind.custom:
|
|
serialized_name = p.CustomName().decode("utf-8")
|
|
if serialized_name not in _export.deserialization_registry:
|
|
raise ValueError(
|
|
"Cannot deserialize a PyTreeDef containing an "
|
|
f"unregistered type `{serialized_name}`. "
|
|
"Use `export.register_pytree_node_serialization` or "
|
|
"`export.register_namedtuple_serialization`.")
|
|
nodetype, deserialize_auxdata, from_iter = _export.deserialization_registry[serialized_name]
|
|
auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes())
|
|
return from_iter(auxdata, children)
|
|
else:
|
|
assert False, kind
|
|
|
|
|
|
_dtype_to_dtype_kind = {
|
|
np.dtype("bool"): ser_flatbuf.DType.bool,
|
|
np.dtype("int8"): ser_flatbuf.DType.i8,
|
|
np.dtype("int16"): ser_flatbuf.DType.i16,
|
|
np.dtype("int32"): ser_flatbuf.DType.i32,
|
|
np.dtype("int64"): ser_flatbuf.DType.i64,
|
|
np.dtype("uint8"): ser_flatbuf.DType.ui8,
|
|
np.dtype("uint16"): ser_flatbuf.DType.ui16,
|
|
np.dtype("uint32"): ser_flatbuf.DType.ui32,
|
|
np.dtype("uint64"): ser_flatbuf.DType.ui64,
|
|
dtypes.float0: ser_flatbuf.DType.f0,
|
|
np.dtype("float16"): ser_flatbuf.DType.f16,
|
|
np.dtype("float32"): ser_flatbuf.DType.f32,
|
|
np.dtype("float64"): ser_flatbuf.DType.f64,
|
|
np.dtype("complex64"): ser_flatbuf.DType.c64,
|
|
np.dtype("complex128"): ser_flatbuf.DType.c128,
|
|
dtypes._bfloat16_dtype: ser_flatbuf.DType.bf16,
|
|
dtypes._int4_dtype: ser_flatbuf.DType.i4,
|
|
dtypes._uint4_dtype: ser_flatbuf.DType.ui4,
|
|
dtypes._float8_e4m3b11fnuz_dtype: ser_flatbuf.DType.f8_e4m3b11fnuz,
|
|
dtypes._float8_e4m3fn_dtype: ser_flatbuf.DType.f8_e4m3fn,
|
|
dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz,
|
|
dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2,
|
|
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
|
|
}
|
|
|
|
if dtypes._float8_e3m4_dtype is not None:
|
|
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
|
|
if dtypes._float8_e4m3_dtype is not None:
|
|
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
|
|
|
|
_dtype_kind_to_dtype = {
|
|
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
|
|
}
|
|
|
|
|
|
def _serialize_aval(
|
|
builder: flatbuffers.Builder, aval: core.ShapedArray
|
|
) -> int:
|
|
aval_kind = ser_flatbuf.AbstractValueKind.shapedArray
|
|
shape_offsets = [builder.CreateString(str(d)) for d in aval.shape]
|
|
ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape))
|
|
for d in reversed(shape_offsets):
|
|
builder.PrependUOffsetTRelative(d)
|
|
shape_vector_offset = builder.EndVector()
|
|
|
|
ser_flatbuf.AbstractValueStart(builder)
|
|
ser_flatbuf.AbstractValueAddKind(builder, aval_kind)
|
|
ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset)
|
|
ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype])
|
|
return ser_flatbuf.AbstractValueEnd(builder)
|
|
|
|
|
|
def _deserialize_aval(aval: ser_flatbuf.AbstractValue,
|
|
scope) -> core.ShapedArray:
|
|
aval_kind = aval.Kind()
|
|
if aval_kind == ser_flatbuf.AbstractValueKind.shapedArray:
|
|
dtype = _dtype_kind_to_dtype[aval.Dtype()]
|
|
shape = shape_poly.symbolic_shape(
|
|
",".join(
|
|
aval.Shape(i).decode("utf-8") for i in range(aval.ShapeLength())
|
|
),
|
|
scope=scope
|
|
)
|
|
return core.ShapedArray(shape, dtype)
|
|
else:
|
|
assert False, aval_kind
|
|
|
|
|
|
def _serialize_sharding(
|
|
builder: flatbuffers.Builder, s: _export.HloSharding | None
|
|
) -> int:
|
|
proto = None
|
|
if s is None:
|
|
kind = ser_flatbuf.ShardingKind.unspecified
|
|
else:
|
|
kind = ser_flatbuf.ShardingKind.hlo_sharding
|
|
proto_bytes = s.to_proto().SerializeToString()
|
|
proto = builder.CreateByteVector(proto_bytes)
|
|
|
|
ser_flatbuf.ShardingStart(builder)
|
|
ser_flatbuf.ShardingAddKind(builder, kind)
|
|
if proto is not None:
|
|
ser_flatbuf.ShardingAddHloShardingProto(builder, proto)
|
|
return ser_flatbuf.ShardingEnd(builder)
|
|
|
|
|
|
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.HloSharding | None:
|
|
kind = s.Kind()
|
|
if kind == ser_flatbuf.ShardingKind.unspecified:
|
|
return None
|
|
|
|
if kind == ser_flatbuf.ShardingKind.hlo_sharding:
|
|
proto_str = s.HloShardingProtoAsNumpy().tobytes()
|
|
proto = xla_client.OpSharding()
|
|
proto.ParseFromString(proto_str)
|
|
|
|
return xla_client.HloSharding.from_proto(proto)
|
|
|
|
assert False, kind
|
|
|
|
|
|
def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int:
|
|
try:
|
|
eff_replica = eff.__class__()
|
|
except Exception:
|
|
raise NotImplementedError(
|
|
f"Effect {eff} must have a nullary constructor to be serializable"
|
|
)
|
|
try:
|
|
hash_eff = hash(eff)
|
|
hash_eff_replica = hash(eff_replica)
|
|
except Exception:
|
|
raise NotImplementedError(
|
|
f"Effect {eff} must be hashable to be serializable"
|
|
)
|
|
if eff != eff_replica or hash_eff != hash_eff_replica:
|
|
raise NotImplementedError(
|
|
f"Effect {eff} must have a nullary class constructor that produces an "
|
|
"equal effect object."
|
|
)
|
|
effect_type_name = str(eff.__class__)
|
|
effect_type_name_offset = builder.CreateString(effect_type_name)
|
|
ser_flatbuf.EffectStart(builder)
|
|
ser_flatbuf.EffectAddTypeName(builder, effect_type_name_offset)
|
|
return ser_flatbuf.ExportedEnd(builder)
|
|
|
|
|
|
def _deserialize_effect(eff: ser_flatbuf.Effect) -> core.Effect:
|
|
effect_type_name = eff.TypeName().decode("utf-8")
|
|
for existing_effect_type in effects.lowerable_effects._effect_types:
|
|
if str(existing_effect_type) == effect_type_name:
|
|
try:
|
|
return existing_effect_type()
|
|
except:
|
|
# TODO: add test
|
|
raise NotImplementedError(
|
|
f"deserializing effect {effect_type_name} that does not have a "
|
|
"nullary class constructor"
|
|
)
|
|
|
|
raise NotImplementedError(
|
|
f"cannot deserialize effect type {effect_type_name}"
|
|
)
|
|
|
|
|
|
def _serialize_disabled_safety_check(
|
|
builder: flatbuffers.Builder, check: _export.DisabledSafetyCheck
|
|
) -> int:
|
|
custom_call_target_str = check.is_custom_call()
|
|
custom_call_target = None
|
|
if custom_call_target_str is not None:
|
|
kind = ser_flatbuf.DisabledSafetyCheckKind.custom_call
|
|
custom_call_target = builder.CreateString(custom_call_target_str)
|
|
elif check == _export.DisabledSafetyCheck.platform():
|
|
kind = ser_flatbuf.DisabledSafetyCheckKind.platform
|
|
else:
|
|
raise NotImplementedError(f"serializing DisabledSafetyCheck: {check}")
|
|
|
|
ser_flatbuf.DisabledSafetyCheckStart(builder)
|
|
ser_flatbuf.DisabledSafetyCheckAddKind(builder, kind)
|
|
if custom_call_target is not None:
|
|
ser_flatbuf.DisabledSafetyCheckAddCustomCallTarget(
|
|
builder, custom_call_target
|
|
)
|
|
return ser_flatbuf.DisabledSafetyCheckEnd(builder)
|
|
|
|
|
|
def _deserialize_disabled_safety_check(
|
|
sc: ser_flatbuf.DisabledSafetyCheck,
|
|
) -> _export.DisabledSafetyCheck:
|
|
kind = sc.Kind()
|
|
if kind == ser_flatbuf.DisabledSafetyCheckKind.custom_call:
|
|
return _export.DisabledSafetyCheck.custom_call(
|
|
sc.CustomCallTarget().decode("utf-8")
|
|
)
|
|
if kind == ser_flatbuf.DisabledSafetyCheckKind.platform:
|
|
return _export.DisabledSafetyCheck.platform()
|
|
if kind == ser_flatbuf.DisabledSafetyCheckKind.shape_assertions:
|
|
# shape_assertions has been deprecated in June 2024 (turned into a no-op),
|
|
# and removed in November 2024. We deserialize it to a DisabledSafetyCheck
|
|
# that has no effect.
|
|
# TODO(necula): remove this after June 2025, when we should not have any
|
|
# more serialized artifacts with shape_assertions.
|
|
return _export.DisabledSafetyCheck.custom_call("no op")
|
|
assert False, kind
|