rocm_jax/jax/experimental/export/serialization_generated.py


[export] Add support for serialization and deserialization of Exported At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process. Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow. Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure). In the process of implementing this we have done some small cleanup of the Exported structure: * renamed serialization_version to mlir_module_serialization_version * renamed disabled_checks to disabled_safety_checks This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export. There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR. PiperOrigin-RevId: 590078785
2023-12-11 23:22:16 -08:00
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: serialization
import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
class PyTreeDefKind(object):
    leaf = 0
    none = 1
    tuple = 2
    list = 3
    dict = 4

class AbstractValueKind(object):
    shapedArray = 0
    abstractToken = 1

class DType(object):
    bool = 0
    i8 = 1
    i16 = 2
    i32 = 3
    i64 = 4
    ui8 = 5
    ui16 = 6
    ui32 = 7
    ui64 = 8
    f16 = 9
    f32 = 10
    f64 = 11
    c64 = 12
    c128 = 13
    bf16 = 14
    i4 = 15
    ui4 = 16
    f8_e4m3b11fnuz = 17
    f8_e4m3fn = 18
    f8_e4m3fnuz = 19
    f8_e5m2 = 20
    f8_e5m2fnuz = 21
    f0 = 22

class ShardingKind(object):
    unspecified = 0
    hlo_sharding = 1

class DisabledSafetyCheckKind(object):
    platform = 0
    custom_call = 1
    shape_assertions = 2

class PyTreeDef(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = PyTreeDef()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsPyTreeDef(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # PyTreeDef
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # PyTreeDef
    def Kind(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # PyTreeDef
    def Children(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = PyTreeDef()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # PyTreeDef
    def ChildrenLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # PyTreeDef
    def ChildrenIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0

    # PyTreeDef
    def ChildrenNames(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return ""

    # PyTreeDef
    def ChildrenNamesLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # PyTreeDef
    def ChildrenNamesIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0

def PyTreeDefStart(builder):
    builder.StartObject(3)

def PyTreeDefAddKind(builder, kind):
    builder.PrependInt8Slot(0, kind, 0)

def PyTreeDefAddChildren(builder, children):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(children), 0)

def PyTreeDefStartChildrenVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def PyTreeDefAddChildrenNames(builder, childrenNames):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(childrenNames), 0)

def PyTreeDefStartChildrenNamesVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def PyTreeDefEnd(builder):
    return builder.EndObject()

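# Illustrative sketch (not generated by flatc): how an accessor such as
# PyTreeDef.Kind() locates its field. Each table begins with an int32
# soffset pointing back to its vtable. The vtable holds a uint16 vtable size,
# a uint16 table size, and then one uint16 field offset per slot. Slot i is
# therefore looked up at vtable offset 4 + 2 * i: slot 0 maps to Offset(4)
# and slot 1 to Offset(6), matching the Add*/accessor pairs above.
# A zero field offset means the field is absent, and so does a vtable too
# short to contain the slot (for example, a buffer written with an older
# schema). In both cases the accessor returns the schema default.
# The stdlib-only helpers and the hand-built buffer are assumptions made for
# illustration; real code should go through flatbuffers.table.Table.
import struct


def _example_field_offset(buf, table_pos, vtable_offset):
    """Mirrors the logic of Table.Offset: 0 means the field is absent."""
    # The table starts with an int32 soffset; the vtable is at table_pos - soffset.
    (soffset,) = struct.unpack_from("<i", buf, table_pos)
    vtable_pos = table_pos - soffset
    (vtable_size,) = struct.unpack_from("<H", buf, vtable_pos)
    if vtable_offset < vtable_size:
        (field_offset,) = struct.unpack_from("<H", buf, vtable_pos + vtable_offset)
        return field_offset
    return 0  # Slot lies beyond this vtable, e.g. the writer used an older schema.


def _example_read_int8(buf, table_pos, vtable_offset, default=0):
    """Reads an int8 field (like PyTreeDef.Kind) or returns the default."""
    field_offset = _example_field_offset(buf, table_pos, vtable_offset)
    if field_offset != 0:
        return struct.unpack_from("<b", buf, table_pos + field_offset)[0]
    return default


def _example_buffer():
    """Hand-built 20-byte buffer: one table whose slot 0 holds the int8 value 2."""
    buf = bytearray(20)
    struct.pack_into("<I", buf, 0, 12)             # root uoffset: table at byte 12
    struct.pack_into("<HHHH", buf, 4, 8, 8, 4, 0)  # vtable: size 8, table size 8, slot 0 at +4, slot 1 absent
    struct.pack_into("<i", buf, 12, 8)             # soffset: vtable at 12 - 8 = 4
    struct.pack_into("<b", buf, 16, 2)             # slot 0 value (2 is PyTreeDefKind.tuple)
    return bytes(buf)

# Usage: with buf = _example_buffer() and root read as the uint32 at byte 0,
# slot 0 (vtable offset 4) reads 2. Slot 1 (offset 6) has a zero entry, and
# slot 2 (offset 8) lies past the 8-byte vtable, so both return the default.
# No Python objects are built until a field is actually read.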
class AbstractValue(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = AbstractValue()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsAbstractValue(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # AbstractValue
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # AbstractValue
    def Kind(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # AbstractValue
    def Shape(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return ""

    # AbstractValue
    def ShapeLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # AbstractValue
    def ShapeIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0

    # AbstractValue
    def Dtype(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

def AbstractValueStart(builder):
    builder.StartObject(3)

def AbstractValueAddKind(builder, kind):
    builder.PrependInt8Slot(0, kind, 0)

def AbstractValueAddShape(builder, shape):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)

def AbstractValueStartShapeVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def AbstractValueAddDtype(builder, dtype):
    builder.PrependInt8Slot(2, dtype, 0)

def AbstractValueEnd(builder):
    return builder.EndObject()

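# Illustrative sketch (not generated by flatc): how vector-of-string
# accessors such as AbstractValue.Shape(j) and PyTreeDef.ChildrenNames(j)
# resolve element j. The field holds a uint32 offset, relative to the field's
# own position, to the vector. The vector is a uint32 length followed by one
# uint32 relative offset per element, hence the `j * 4` stride. Each string
# is a uint32 byte length followed by the bytes and a NUL terminator.
# flatbuffers' Table.String returns `bytes`, not `str`, so callers typically
# decode the result. The helper names and the hand-built buffer are
# assumptions for illustration only.
import struct


def _example_follow(buf, pos):
    """Follows the uint32 offset stored at pos, which is relative to pos."""
    return pos + struct.unpack_from("<I", buf, pos)[0]


def _example_vector_len(buf, field_pos):
    # Like Table.VectorLen: the length prefix sits where the field's offset points.
    return struct.unpack_from("<I", buf, _example_follow(buf, field_pos))[0]


def _example_vector_string(buf, field_pos, j):
    # Like Table.Vector followed by Table.String: skip the length prefix,
    # index with a 4-byte stride, then follow the element's offset to the string.
    elements = _example_follow(buf, field_pos) + 4
    string_pos = _example_follow(buf, elements + 4 * j)
    length = struct.unpack_from("<I", buf, string_pos)[0]
    return bytes(buf[string_pos + 4:string_pos + 4 + length])


def _example_string_vector_buffer():
    """Hand-built 32-byte buffer: a field at byte 0 pointing to [b"a", b"key"]."""
    buf = bytearray(32)
    struct.pack_into("<I", buf, 0, 4)    # field: vector at 0 + 4
    struct.pack_into("<I", buf, 4, 2)    # vector length 2
    struct.pack_into("<I", buf, 8, 8)    # element 0: string at 8 + 8 = 16
    struct.pack_into("<I", buf, 12, 12)  # element 1: string at 12 + 12 = 24
    struct.pack_into("<I", buf, 16, 1)   # len(b"a")
    buf[20:21] = b"a"                    # byte 21 stays 0 as the NUL terminator
    struct.pack_into("<I", buf, 24, 3)   # len(b"key")
    buf[28:31] = b"key"                  # byte 31 stays 0 as the NUL terminator
    return bytes(buf)

# Usage: _example_vector_len(buf, 0) gives 2, and
# _example_vector_string(buf, 0, 1).decode("utf-8") gives "key".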
class Sharding(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Sharding()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSharding(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # Sharding
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # Sharding
    def Kind(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # Sharding
    def HloShardingProto(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
        return 0

    # Sharding
    def HloShardingProtoAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
        return 0

    # Sharding
    def HloShardingProtoLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Sharding
    def HloShardingProtoIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return o == 0

def ShardingStart(builder):
    builder.StartObject(2)

def ShardingAddKind(builder, kind):
    builder.PrependInt8Slot(0, kind, 0)

def ShardingAddHloShardingProto(builder, hloShardingProto):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(hloShardingProto), 0)

def ShardingStartHloShardingProtoVector(builder, numElems):
    return builder.StartVector(1, numElems, 1)

def ShardingEnd(builder):
    return builder.EndObject()

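# Illustrative sketch (not generated by flatc): the HLO sharding proto is
# stored as a vector of int8. HloShardingProto(j) therefore yields signed
# values in [-128, 127], read with a 1-byte stride (`j * 1`). Recovering the
# original serialized bytes means reinterpreting each value as an unsigned
# byte. When the field is present, HloShardingProtoAsNumpy().tobytes() does
# the same in one step. Going the other way, a writer would convert bytes to
# int8 values before filling the vector opened by
# ShardingStartHloShardingProtoVector. The helper names below are assumptions
# for illustration.
import struct


def _example_int8_values_to_bytes(values):
    """Packs signed int8 values (e.g. from HloShardingProto(j)) back into bytes."""
    return struct.pack(f"<{len(values)}b", *values)


def _example_bytes_to_int8_values(data):
    """Reinterprets raw bytes as the signed int8 values stored in the vector."""
    return list(struct.unpack(f"<{len(data)}b", data))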
class Effect(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Effect()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsEffect(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # Effect
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # Effect
    def TypeName(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

def EffectStart(builder):
    builder.StartObject(1)

def EffectAddTypeName(builder, typeName):
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(typeName), 0)

def EffectEnd(builder):
    return builder.EndObject()

class DisabledSafetyCheck(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = DisabledSafetyCheck()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsDisabledSafetyCheck(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # DisabledSafetyCheck
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # DisabledSafetyCheck
    def Kind(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

    # DisabledSafetyCheck
    def CustomCallTarget(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

def DisabledSafetyCheckStart(builder):
    builder.StartObject(2)

def DisabledSafetyCheckAddKind(builder, kind):
    builder.PrependInt8Slot(0, kind, 0)

def DisabledSafetyCheckAddCustomCallTarget(builder, customCallTarget):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(customCallTarget), 0)

def DisabledSafetyCheckEnd(builder):
    return builder.EndObject()

class Exported(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Exported()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsExported(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # Exported
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # We increment the serialization version every time we change the
    # schema, even if the change is backwards compatible.
    # Note that this field has different semantics and purpose from
    # `mlir_module_serialization_version`, which encodes
    # the calling convention of the `mlir_module_serialized`.
    # Exported
    def SerializationVersion(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos)
        return 0

    # Exported
    def FunctionName(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # Exported
    def InTree(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            obj = PyTreeDef()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def InAvals(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = AbstractValue()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def InAvalsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def InAvalsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        return o == 0

    # Exported
    def OutTree(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            obj = PyTreeDef()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def OutAvals(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = AbstractValue()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def OutAvalsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def OutAvalsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        return o == 0

    # Exported
    def NrDevices(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int16Flags, o + self._tab.Pos)
        return 0

    # Exported
    def InShardings(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = Sharding()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def InShardingsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def InShardingsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        return o == 0

    # Exported
    def OutShardings(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = Sharding()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def OutShardingsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def OutShardingsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        return o == 0

    # Exported
    def LoweringPlatforms(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return ""

    # Exported
    def LoweringPlatformsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def LoweringPlatformsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        return o == 0

    # Exported
    def OrderedEffects(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = Effect()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def OrderedEffectsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def OrderedEffectsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
        return o == 0

    # Exported
    def UnorderedEffects(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = Effect()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def UnorderedEffectsLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def UnorderedEffectsIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
        return o == 0

    # Exported
    def DisabledChecks(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            obj = DisabledSafetyCheck()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Exported
    def DisabledChecksLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def DisabledChecksIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
        return o == 0

    # Exported
    def MlirModuleSerialized(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
        return 0

    # Exported
    def MlirModuleSerializedAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
        return 0

    # Exported
    def MlirModuleSerializedLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def MlirModuleSerializedIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(30))
        return o == 0

    # Exported
    def MlirModuleSerializationVersion(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(32))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos)
        return 0

    # Exported
    def ModuleKeptVarIdx(self, j):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(34))
        if o != 0:
            a = self._tab.Vector(o)
            return self._tab.Get(flatbuffers.number_types.Uint16Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 2))
        return 0

    # Exported
    def ModuleKeptVarIdxAsNumpy(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(34))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint16Flags, o)
        return 0

    # Exported
    def ModuleKeptVarIdxLength(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(34))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Exported
    def ModuleKeptVarIdxIsNone(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(34))
        return o == 0

    # Exported
    def UsesShapePolymorphism(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(36))
        if o != 0:
            return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
        return False

    # Exported
    def Vjp(self):
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(38))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            obj = Exported()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

def ExportedStart(builder):
    builder.StartObject(18)

def ExportedAddSerializationVersion(builder, serializationVersion):
    builder.PrependUint16Slot(0, serializationVersion, 0)

def ExportedAddFunctionName(builder, functionName):
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(functionName), 0)

def ExportedAddInTree(builder, inTree):
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(inTree), 0)

def ExportedAddInAvals(builder, inAvals):
    builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(inAvals), 0)

def ExportedStartInAvalsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def ExportedAddOutTree(builder, outTree):
    builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(outTree), 0)

def ExportedAddOutAvals(builder, outAvals):
    builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(outAvals), 0)

def ExportedStartOutAvalsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def ExportedAddNrDevices(builder, nrDevices):
    builder.PrependInt16Slot(6, nrDevices, 0)

def ExportedAddInShardings(builder, inShardings):
    builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(inShardings), 0)

def ExportedStartInShardingsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def ExportedAddOutShardings(builder, outShardings):
    builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(outShardings), 0)

def ExportedStartOutShardingsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def ExportedAddLoweringPlatforms(builder, loweringPlatforms):
    builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(loweringPlatforms), 0)

def ExportedStartLoweringPlatformsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def ExportedAddOrderedEffects(builder, orderedEffects):
    builder.PrependUOffsetTRelativeSlot(10, flatbuffers.number_types.UOffsetTFlags.py_type(orderedEffects), 0)

def ExportedStartOrderedEffectsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def ExportedAddUnorderedEffects(builder, unorderedEffects):
    builder.PrependUOffsetTRelativeSlot(11, flatbuffers.number_types.UOffsetTFlags.py_type(unorderedEffects), 0)

def ExportedStartUnorderedEffectsVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def ExportedAddDisabledChecks(builder, disabledChecks):
    builder.PrependUOffsetTRelativeSlot(12, flatbuffers.number_types.UOffsetTFlags.py_type(disabledChecks), 0)

def ExportedStartDisabledChecksVector(builder, numElems):
    return builder.StartVector(4, numElems, 4)

def ExportedAddMlirModuleSerialized(builder, mlirModuleSerialized):
    builder.PrependUOffsetTRelativeSlot(13, flatbuffers.number_types.UOffsetTFlags.py_type(mlirModuleSerialized), 0)

def ExportedStartMlirModuleSerializedVector(builder, numElems):
    return builder.StartVector(1, numElems, 1)

def ExportedAddMlirModuleSerializationVersion(builder, mlirModuleSerializationVersion):
    builder.PrependUint16Slot(14, mlirModuleSerializationVersion, 0)

def ExportedAddModuleKeptVarIdx(builder, moduleKeptVarIdx):
    builder.PrependUOffsetTRelativeSlot(15, flatbuffers.number_types.UOffsetTFlags.py_type(moduleKeptVarIdx), 0)

def ExportedStartModuleKeptVarIdxVector(builder, numElems):
    return builder.StartVector(2, numElems, 2)

def ExportedAddUsesShapePolymorphism(builder, usesShapePolymorphism):
    builder.PrependBoolSlot(16, usesShapePolymorphism, 0)

def ExportedAddVjp(builder, vjp):
    builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(vjp), 0)

def ExportedEnd(builder):
    return builder.EndObject()
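
# Illustrative sketch, not generated by flatc and not part of the JAX
# serialization API: a stdlib-only reader that mirrors how the accessors above
# locate a scalar field. The slot index passed by a builder function (e.g.
# slot 0 in ExportedAddSerializationVersion) maps to byte offset 4 + 2 * slot
# in the table's vtable. This is why SerializationVersion() calls
# self._tab.Offset(4) and Vjp() (slot 17) calls self._tab.Offset(38).
# A zero vtable entry, or a slot past the end of a shorter vtable written with
# an older schema, yields the default. This is what lets a newer reader accept
# buffers that lack newly added fields. The name `read_uint16_field` is
# hypothetical, and the sketch assumes a little-endian buffer whose root is a
# table.
import struct


def read_uint16_field(buf, slot, default=0):
    """Reads a uint16 scalar field from the root table of a flatbuffer."""
    # The buffer starts with a uint32 offset to the root table.
    table_pos = struct.unpack_from("<I", buf, 0)[0]
    # The table starts with an int32 soffset equal to table_pos - vtable_pos.
    vtable_pos = table_pos - struct.unpack_from("<i", buf, table_pos)[0]
    # Vtable layout: uint16 vtable size, uint16 table size, then one uint16
    # offset (relative to table_pos) per field slot.
    vtable_size = struct.unpack_from("<H", buf, vtable_pos)[0]
    vt_offset = 4 + 2 * slot
    if vt_offset >= vtable_size:
        # The writer's (older) schema did not know about this field.
        return default
    field_offset = struct.unpack_from("<H", buf, vtable_pos + vt_offset)[0]
    if field_offset == 0:
        # The field was omitted, e.g. because it equaled the default.
        return default
    return struct.unpack_from("<H", buf, table_pos + field_offset)[0]


# Usage: a hand-built buffer with slot 0 set to 7 and slot 1 absent.
#   buf = (struct.pack("<I", 12)               # root offset: table at byte 12
#          + struct.pack("<HHHH", 8, 8, 4, 0)  # vtable at byte 4, 2 slots
#          + struct.pack("<i", 8)              # vtable is 8 bytes before table
#          + struct.pack("<HH", 7, 0))         # slot 0 value, then padding
#   read_uint16_field(buf, 0)   # 7
#   read_uint16_field(buf, 1)   # 0, slot present in vtable but absent
#   read_uint16_field(buf, 17)  # 0, slot beyond this vtable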