mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 09:16:08 +00:00

This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general, JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to MLIR/StableHLO utilities instead of classic HLO utilities. Since most users of custom kernels probably want to build a custom call, we can get most of the benefit by providing an ergonomic helper function, callable from external primitive lowering rules, that builds the IR for custom calls. This function has two benefits over building the StableHLO directly: (a) it is a JAX API, so we can be more confident that it won't change because of upstream MLIR changes; and (b) the Python API for building stablehlo.custom_call that is generated by the bindings isn't easy to use (e.g., it doesn't have sensible defaults). The next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules that use this helper. PiperOrigin-RevId: 561042402
71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from jax._src.interpreters.mlir import (
|
|
AxisContext as AxisContext,
|
|
ConstantHandler as ConstantHandler,
|
|
DEVICE_TO_DEVICE_TYPE as DEVICE_TO_DEVICE_TYPE,
|
|
LoweringResult as LoweringResult,
|
|
LoweringRule as LoweringRule,
|
|
LoweringRuleContext as LoweringRuleContext,
|
|
ModuleContext as ModuleContext,
|
|
RECV_FROM_HOST_TYPE as RECV_FROM_HOST_TYPE,
|
|
SEND_TO_HOST_TYPE as SEND_TO_HOST_TYPE,
|
|
Token as Token,
|
|
TokenSet as TokenSet,
|
|
Value as Value,
|
|
_call_lowering as _call_lowering,
|
|
_lowerings as _lowerings,
|
|
_platform_specific_lowerings as _platform_specific_lowerings,
|
|
aval_to_ir_type as aval_to_ir_type,
|
|
aval_to_ir_types as aval_to_ir_types,
|
|
core_call_lowering as core_call_lowering,
|
|
custom_call as custom_call,
|
|
dense_bool_elements as dense_bool_elements,
|
|
dense_int_elements as dense_int_elements,
|
|
dtype_to_ir_type as dtype_to_ir_type,
|
|
emit_python_callback as emit_python_callback,
|
|
flatten_lowering_ir_args as flatten_lowering_ir_args,
|
|
func_dialect as func_dialect,
|
|
hlo as hlo,
|
|
i32_attr as i32_attr,
|
|
i64_attr as i64_attr,
|
|
ir as ir,
|
|
ir_constant as ir_constant,
|
|
ir_constants as ir_constants,
|
|
ir_type_handlers as ir_type_handlers,
|
|
jaxpr_subcomp as jaxpr_subcomp,
|
|
lower_fun as lower_fun,
|
|
lower_jaxpr_to_fun as lower_jaxpr_to_fun,
|
|
lower_jaxpr_to_module as lower_jaxpr_to_module,
|
|
lowerable_effects as lowerable_effects,
|
|
make_ir_context as make_ir_context,
|
|
merge_mlir_modules as merge_mlir_modules,
|
|
module_to_bytecode as module_to_bytecode,
|
|
module_to_string as module_to_string,
|
|
register_constant_handler as register_constant_handler,
|
|
register_lowering as register_lowering,
|
|
shape_tensor as shape_tensor,
|
|
token_type as token_type,
|
|
xla_computation_to_mlir_module as xla_computation_to_mlir_module,
|
|
)
|
|
|
|
from jax._src.mesh import Mesh as Mesh
|
|
from jax._src.sharding_impls import (
|
|
MeshAxisName as MeshAxisName,
|
|
ReplicaAxisContext as ReplicaAxisContext,
|
|
SPMDAxisContext as SPMDAxisContext,
|
|
ShardingContext as ShardingContext,
|
|
)
|