// rocm_jax/jax/_src/export/serialization.fbs
// 2025-01-22 21:57:43 +00:00
//
// 143 lines
// 3.0 KiB
// Plaintext

// Copyright 2023 The JAX Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// To regenerate serialization_generated.py, install flatc (e.g.,
// from Homebrew) and then:
//
// 1. Run flatc --python --gen-onefile serialization.fbs
// 2. Delete the trailing newlines at the end of the generated file
// 3. Add back the license comment at the start of the generated file
//
// Namespace under which flatc emits every type declared in this schema.
namespace jax_export.serialization;
// Discriminator stored in `PyTreeDef.kind`; it tells the reader which of the
// optional `PyTreeDef` fields are meaningful for a given node.
// The numeric values are written explicitly; keep them stable.
enum PyTreeDefKind: byte {
  leaf = 0,
  none = 1,
  tuple = 2,
  list = 3,
  dict = 4,    // nodes of this kind use `PyTreeDef.children_names`
  custom = 5,  // nodes of this kind use `PyTreeDef.custom_name` / `custom_auxdata`
}
// One node of a (recursive) pytree structure description.
// Field order is significant: without explicit `id` attributes, flatc assigns
// field slots in declaration order, so new fields must be appended at the end.
table PyTreeDef {
  // Which kind of node this is.
  kind: PyTreeDefKind;

  // Sub-trees of this node.
  children: [PyTreeDef];

  // only for "kind==dict"
  children_names: [string];

  // only for "kind==custom"
  custom_name: string;

  // only for "kind==custom"
  custom_auxdata: [byte];
}
// Discriminator stored in `AbstractValue.kind`.
enum AbstractValueKind: byte {
  shapedArray = 0,
  // unused
  abstractToken = 1,
}
// Element-type tag stored in `AbstractValue.dtype`.
//
// Ids are assigned explicitly and are NOT in declaration order (entries are
// grouped by family instead). When adding a type, give it the next unused id
// and update the "Last used id" line below. Do not renumber or reuse existing
// ids: the numeric value is what ends up in the serialized data.
enum DType: byte {
  // Last used id: 25
  bool = 0,
  i8 = 1,
  i16 = 2,
  i32 = 3,
  i64 = 4,
  ui8 = 5,
  ui16 = 6,
  ui32 = 7,
  ui64 = 8,
  f0 = 22,  // Used in JAX to represent float0
  f16 = 9,
  f32 = 10,
  f64 = 11,
  c64 = 12,
  c128 = 13,
  bf16 = 14,
  i4 = 15,
  ui4 = 16,
  f8_e3m4 = 24,
  f8_e4m3 = 23,
  f8_e4m3b11fnuz = 17,
  f8_e4m3fn = 18,
  f8_e4m3fnuz = 19,
  f8_e5m2 = 20,
  f8_e5m2fnuz = 21,
  f8_e8m0fnu = 25,
}
// Description of one abstract value (kind, shape, element type).
// Field order is significant; append new fields at the end.
table AbstractValue {
  kind: AbstractValueKind;

  // Support shape polymorphism
  // (presumably one string per dimension so that symbolic sizes can be
  // represented -- confirm against the serializer).
  shape: [string];

  dtype: DType;
}
// Discriminator stored in `Sharding.kind`.
// The values are spelled out explicitly; they are the same values flatc
// assigns implicitly (starting at 0 and incrementing by one).
enum ShardingKind: byte {
  unspecified = 0,
  hlo_sharding = 1,
}
// Sharding description for one input or output.
// Field order is significant; append new fields at the end.
table Sharding {
  kind: ShardingKind;

  // Raw bytes; presumably a serialized HLO sharding proto that is only
  // meaningful when kind == hlo_sharding -- confirm against the serializer.
  hlo_sharding_proto: [byte];
}
// A single effect entry, identified only by a type name string.
table Effect {
  type_name: string;
}
// Discriminator stored in `DisabledSafetyCheck.kind`.
// The values are spelled out explicitly; they are the same values flatc
// assigns implicitly (starting at 0 and incrementing by one).
enum DisabledSafetyCheckKind: byte {
  platform = 0,
  custom_call = 1,
  // unused
  shape_assertions = 2,
}
// One safety check that has been disabled.
// Field order is significant; append new fields at the end.
table DisabledSafetyCheck {
  kind: DisabledSafetyCheckKind;

  // Presumably only meaningful when kind == custom_call -- confirm against
  // the serializer.
  custom_call_target: string;
}
// Top-level record of the format; it is the schema's declared `root_type`.
//
// Field order is significant: without explicit `id` attributes, flatc assigns
// field slots in declaration order, so fields must not be reordered or removed
// and new fields must be appended at the end.
table Exported {
  // NOTE(review): the doc comment below refers to
  // `mlir_module_serialization_version`, but no field of that name exists in
  // this table; it presumably means `calling_convention_version` -- confirm.
  /// We increment the serialization version every time we change the
  /// schema, even if the change is backwards compatible.
  /// Note that this field has different semantics and purpose from
  /// `mlir_module_serialization_version`, which encodes
  /// the calling convention of the `mlir_module_serialized`.
  serialization_version: uint16;

  function_name: string;

  // Input structure and per-input abstract values.
  in_tree: PyTreeDef;
  in_avals: [AbstractValue];

  // Output structure and per-output abstract values.
  out_tree: PyTreeDef;
  out_avals: [AbstractValue];

  // `int16` is the exact alias of `short`; spelled this way to match the
  // `uint16` fields in this table.
  nr_devices: int16;

  in_shardings: [Sharding];
  out_shardings: [Sharding];

  platforms: [string];

  ordered_effects: [Effect];
  unordered_effects: [Effect];

  disabled_checks: [DisabledSafetyCheck];

  // Serialized module bytes and the version of its calling convention.
  mlir_module_serialized: [byte];
  calling_convention_version: uint16;

  module_kept_var_idx: [uint16];
  uses_global_constants: bool;

  // Optional nested record of the same type (presumably the exported VJP
  // function -- confirm against the serializer).
  vjp: Exported;
}
// `Exported` is the root table of a serialized buffer.
root_type Exported;