Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Merge pull request #18396 from nouiz:custom_gpu_ops
PiperOrigin-RevId: 580680809
Commit: 21260a7a65

This change updates the custom-GPU-ops tutorial (and the mirrored full listing later in the same document) to the current jaxlib custom_call API: the out_types keyword becomes result_types, the call's outputs are unpacked via .results, and ShapedArray is imported from jax.core instead of the deprecated jax.abstract_arrays. It also extends an error message in jaxpr_subcomp to name the offending translation rule.
@@ -236,7 +236,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
     )
     out = custom_call(
         b"rms_forward_affine_mixed_dtype",
-        out_types=[
+        result_types=[
             ir.RankedTensorType.get(x_shape, w_type.element_type),
             ir.RankedTensorType.get((n1,), iv_element_type),
         ],
@@ -244,7 +244,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, w_shape),
         result_layouts=default_layouts(x_shape, (n1,)),
-    )
+    ).results
     return out


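Taken together, these first two hunks move the forward lowering to the newer custom_call helper: the keyword is now result_types (matching result_layouts), and the helper returns the op itself, so its outputs are unpacked with .results. A minimal sketch of the updated function, assuming custom_call comes from jaxlib.hlo_helpers as in the tutorial and that the elided bookkeeping (x_shape, w_shape, n1, w_type, iv_element_type, opaque, default_layouts) is in scope; the operands argument shown here is illustrative:

    from jaxlib.hlo_helpers import custom_call

    def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
        # ... shape/dtype bookkeeping elided; see the tutorial ...
        out = custom_call(
            b"rms_forward_affine_mixed_dtype",
            result_types=[  # renamed from out_types
                ir.RankedTensorType.get(x_shape, w_type.element_type),
                ir.RankedTensorType.get((n1,), iv_element_type),
            ],
            operands=[x, weight],
            backend_config=opaque,
            operand_layouts=default_layouts(x_shape, w_shape),
            result_layouts=default_layouts(x_shape, (n1,)),
        ).results  # custom_call now returns the op; take its .results
        return out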
@@ -277,7 +277,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
     )
     out = custom_call(
         b"rms_backward_affine",
-        out_types=[
+        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
@@ -286,7 +286,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
         result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
-    )
+    ).results
     return out


@@ -345,7 +345,7 @@ See [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How
 from functools import reduce
 from operator import mul

-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray


 def _rms_norm_fwd_abstract(x, weight, eps):
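ShapedArray now lives in jax.core; jax.abstract_arrays was deprecated and has since been removed. A minimal sketch of a forward abstract-eval rule built on the new import — the dtype promotion and row count mirror the tutorial's RMS-norm setup, but treat the details as illustrative rather than the document's exact code:

    from functools import reduce
    from operator import mul

    import jax.numpy as jnp
    from jax import dtypes
    from jax.core import ShapedArray


    def _rms_norm_fwd_abstract(x, weight, eps):
        # Output keeps the weight dtype; the per-row inverse variance is
        # accumulated in float32 when the input is half precision.
        w_dtype = dtypes.canonicalize_dtype(weight.dtype)
        iv_dtype = dtypes.canonicalize_dtype(x.dtype)
        if iv_dtype in (jnp.float16, jnp.bfloat16):
            iv_dtype = jnp.float32
        n2 = reduce(mul, weight.shape)    # elements normalized together
        n1 = reduce(mul, x.shape) // n2   # number of rows
        return (
            ShapedArray(x.shape, w_dtype),   # normalized output
            ShapedArray((n1,), iv_dtype),    # per-row inverse variance
        )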
@@ -804,7 +804,7 @@ import jax
 import jax.numpy as jnp
 from build import gpu_ops
 from jax import core, dtypes
-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray
 from jax.experimental.maps import xmap
 from jax.experimental.pjit import pjit
 from jax.interpreters import mlir, xla
@@ -887,7 +887,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
     )
     out = custom_call(
         b"rms_forward_affine_mixed_dtype",
-        out_types=[
+        result_types=[
             ir.RankedTensorType.get(x_shape, w_type.element_type),
             ir.RankedTensorType.get((n1,), iv_element_type),
         ],
@@ -895,7 +895,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, w_shape),
         result_layouts=default_layouts(x_shape, (n1,)),
-    )
+    ).results
     return out


@@ -928,7 +928,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
     )
     out = custom_call(
         b"rms_backward_affine",
-        out_types=[
+        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
@@ -937,7 +937,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
         result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
-    )
+    ).results
     return out


@@ -1453,7 +1453,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
         out_nodes = tuple(map(wrap_singleton_ir_values, ans))
       except TypeError as e:
         raise ValueError("Output of translation rule must be iterable: "
-                         f"{eqn}, got output {ans}") from e
+                         f"{eqn}, got output {ans} from {rule.__name__}") from e

     assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
     assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
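The final hunk improves a diagnostic in jaxpr_subcomp: when a lowering rule returns something non-iterable, the ValueError now also names the offending rule via rule.__name__. A toy, self-contained reproduction of the pattern — apply_rule and bad_rule are illustrative stand-ins, not JAX internals:

    def apply_rule(rule, eqn, *args):
        ans = rule(*args)
        try:
            # Stand-in for wrap_singleton_ir_values over the rule's output.
            out_nodes = tuple(ans)
        except TypeError as e:
            raise ValueError("Output of translation rule must be iterable: "
                             f"{eqn}, got output {ans} from {rule.__name__}") from e
        return out_nodes


    def bad_rule():
        return 42  # not iterable, so the improved message fires


    # apply_rule(bad_rule, "eqn") raises:
    #   ValueError: Output of translation rule must be iterable: eqn,
    #   got output 42 from bad_rule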