mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 00:56:04 +00:00
143 lines
3.0 KiB
Plaintext
143 lines
3.0 KiB
Plaintext
// Copyright 2023 The JAX Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// https://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// To regenerate the serialization_generated.py, install flatc (e.g.,
|
|
// from Homebrew) and then:
|
|
//
|
|
// 1. Run flatc --python --gen-onefile serialization.fbs
|
|
// 2. Delete the trailing newlines at the end
|
|
// 3. Add back the license comment at the start
|
|
//
|
|
|
|
// All enums and tables declared in this schema live in this namespace.
namespace jax_export.serialization;
|
|
|
|
// Discriminator for the node kinds of a serialized PyTreeDef (it is the type
// of the `PyTreeDef.kind` field). Stored as a signed 8-bit value; `int8` is
// the sized spelling of `byte`.
enum PyTreeDefKind: int8 {
  leaf   = 0,
  none   = 1,
  tuple  = 2,
  list   = 3,
  dict   = 4,  // the kind for which PyTreeDef.children_names is filled in
  custom = 5,  // the kind for which PyTreeDef.custom_name/custom_auxdata are filled in
}
|
|
|
|
// One node of a serialized pytree structure. The table is recursive: the
// `children` vector holds further PyTreeDef nodes.
//
// Field ids are written out explicitly; they equal the positions the fields
// had when ids were assigned implicitly by declaration order.
table PyTreeDef {
  kind: PyTreeDefKind (id: 0);
  children: [PyTreeDef] (id: 1);

  // Only for "kind==dict".
  children_names: [string] (id: 2);

  // Only for "kind==custom".
  custom_name: string (id: 3);

  // Only for "kind==custom".
  custom_auxdata: [byte] (id: 4);
}
|
|
|
|
// Discriminator stored in `AbstractValue.kind`. Stored as a signed 8-bit
// value; `int8` is the sized spelling of `byte`.
enum AbstractValueKind: int8 {
  shapedArray   = 0,
  abstractToken = 1,  // unused
}
|
|
|
|
// Element type tag stored in `AbstractValue.dtype`.
//
// Members are grouped by family (bool, signed/unsigned ints, floats, complex,
// bfloat16, 4-bit ints, 8-bit floats), NOT sorted by id: ids were handed out
// in the order the dtypes were added. Because the id is the serialized value,
// give a new dtype the next unused id instead of renumbering existing ones.
enum DType: byte {
  // Last used id: 25
  bool = 0,
  i8 = 1,
  i16 = 2,
  i32 = 3,
  i64 = 4,
  ui8 = 5,
  ui16 = 6,
  ui32 = 7,
  ui64 = 8,
  f0 = 22,  // Used in JAX to represent float0
  f16 = 9,
  f32 = 10,
  f64 = 11,
  c64 = 12,
  c128 = 13,

  bf16 = 14,

  i4 = 15,
  ui4 = 16,

  f8_e3m4 = 24,
  f8_e4m3 = 23,
  f8_e4m3b11fnuz = 17,
  f8_e4m3fn = 18,
  f8_e4m3fnuz = 19,
  f8_e5m2 = 20,
  f8_e5m2fnuz = 21,
  f8_e8m0fnu = 25,
}
|
|
|
|
// Serialized description of one abstract value: its kind, shape and dtype.
//
// Field ids are written out explicitly; they equal the positions the fields
// had when ids were assigned implicitly by declaration order.
table AbstractValue {
  kind: AbstractValueKind (id: 0);

  // Each dimension is carried as a string to support shape polymorphism.
  shape: [string] (id: 1);

  dtype: DType (id: 2);
}
|
|
|
|
// Discriminator stored in `Sharding.kind`. The values are spelled out; they
// are the same ones the schema compiler assigns implicitly (counting up from
// 0 in declaration order). `int8` is the sized spelling of `byte`.
enum ShardingKind: int8 {
  unspecified  = 0,
  hlo_sharding = 1,
}
|
|
|
|
// Serialized sharding: a kind tag plus an opaque byte payload.
//
// Field ids are written out explicitly; they equal the positions the fields
// had when ids were assigned implicitly by declaration order.
table Sharding {
  kind: ShardingKind (id: 0);

  // NOTE(review): presumably a serialized HLO sharding proto that is only
  // meaningful when kind == hlo_sharding -- confirm against the
  // (de)serialization code.
  hlo_sharding_proto: [byte] (id: 1);
}
|
|
|
|
// Serialized effect; the only thing recorded about it is a type name string.
// The single field carries an explicit id equal to its implicit position.
table Effect {
  type_name: string (id: 0);
}
|
|
|
|
// Discriminator stored in `DisabledSafetyCheck.kind`. The values are spelled
// out; they are the same ones the schema compiler assigns implicitly
// (counting up from 0 in declaration order). `int8` is the sized spelling of
// `byte`.
enum DisabledSafetyCheckKind: int8 {
  platform         = 0,
  custom_call      = 1,
  shape_assertions = 2,  // unused
}
|
|
|
|
// One disabled safety check: a kind tag plus an optional target name.
//
// Field ids are written out explicitly; they equal the positions the fields
// had when ids were assigned implicitly by declaration order.
table DisabledSafetyCheck {
  kind: DisabledSafetyCheckKind (id: 0);

  // NOTE(review): presumably only meaningful when kind == custom_call --
  // confirm against the (de)serialization code.
  custom_call_target: string (id: 1);
}
|
|
|
|
// Root table of the serialized form (see the `root_type` declaration at the
// end of this schema). The table is recursive through the `vjp` field.
//
// Fields carry no explicit `id` attributes, so each field's position in this
// list determines its slot: append new fields at the end and do not reorder
// or remove existing ones.
table Exported {
  /// We increment the serialization version every time we change the
  /// schema, even if the change is backwards compatible.
  /// Note that this field has different semantics and purpose from
  /// `calling_convention_version`, which encodes
  /// the calling convention of the `mlir_module_serialized`.
  serialization_version: uint16;

  function_name: string;
  in_tree: PyTreeDef;
  in_avals: [AbstractValue];
  out_tree: PyTreeDef;
  out_avals: [AbstractValue];
  nr_devices: short;
  in_shardings: [Sharding];
  out_shardings: [Sharding];

  platforms: [string];

  ordered_effects: [Effect];
  unordered_effects: [Effect];
  disabled_checks: [DisabledSafetyCheck];

  mlir_module_serialized: [byte];
  // Version of the calling convention of `mlir_module_serialized`; distinct
  // from `serialization_version`, which tracks changes to this schema.
  calling_convention_version: uint16;
  module_kept_var_idx: [uint16];
  uses_global_constants: bool;

  // A nested Exported. NOTE(review): presumably the exported VJP function,
  // absent when there is none -- confirm against the serialization code.
  vjp: Exported;
}
|
|
|
|
// Declares `Exported` as the root table of a buffer using this schema.
root_type Exported;
|