Merge pull request #18396 from nouiz:custom_gpu_ops

PiperOrigin-RevId: 580680809
jax authors 2023-11-08 15:39:22 -08:00
commit 21260a7a65
2 changed files with 11 additions and 11 deletions

View File

@@ -236,7 +236,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
     )
     out = custom_call(
         b"rms_forward_affine_mixed_dtype",
-        out_types=[
+        result_types=[
             ir.RankedTensorType.get(x_shape, w_type.element_type),
             ir.RankedTensorType.get((n1,), iv_element_type),
         ],
@@ -244,7 +244,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, w_shape),
         result_layouts=default_layouts(x_shape, (n1,)),
-    )
+    ).results
     return out
@@ -277,7 +277,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
     )
     out = custom_call(
         b"rms_backward_affine",
-        out_types=[
+        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
@@ -286,7 +286,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
         result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
-    )
+    ).results
     return out
@@ -345,7 +345,7 @@ See [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How
 from functools import reduce
 from operator import mul
-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray
 def _rms_norm_fwd_abstract(x, weight, eps):
@@ -804,7 +804,7 @@ import jax
 import jax.numpy as jnp
 from build import gpu_ops
 from jax import core, dtypes
-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray
 from jax.experimental.maps import xmap
 from jax.experimental.pjit import pjit
 from jax.interpreters import mlir, xla
@@ -887,7 +887,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
     )
     out = custom_call(
         b"rms_forward_affine_mixed_dtype",
-        out_types=[
+        result_types=[
             ir.RankedTensorType.get(x_shape, w_type.element_type),
             ir.RankedTensorType.get((n1,), iv_element_type),
         ],
@@ -895,7 +895,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, w_shape),
         result_layouts=default_layouts(x_shape, (n1,)),
-    )
+    ).results
     return out
@@ -928,7 +928,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
     )
     out = custom_call(
         b"rms_backward_affine",
-        out_types=[
+        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
@@ -937,7 +937,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
         result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
-    )
+    ).results
     return out
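
A minimal sketch of the pattern the hunks above migrate to: the updated `custom_call` helper takes `result_types=` instead of `out_types=` and returns an op rather than the result values, so the lowering rule ends with `.results`. The kernel name `my_kernel` below is made up for illustration, and the `operands=` keyword is assumed from the helper's interface rather than taken from this diff.

```python
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call


def _my_kernel_lowering(ctx, x):
    # Hypothetical rule: one operand, one result of the same shape and dtype.
    x_type = ir.RankedTensorType(x.type)
    layout = list(range(len(x_type.shape) - 1, -1, -1))  # row-major layout
    op = custom_call(
        b"my_kernel",  # made-up custom-call target name for this sketch
        result_types=[ir.RankedTensorType.get(x_type.shape, x_type.element_type)],
        operands=[x],  # assumed keyword; not visible in the hunks above
        operand_layouts=[layout],
        result_layouts=[layout],
    )
    return op.results  # return the op's results, not the op itself
```

Registering such a rule would look the same as elsewhere in the tutorial, e.g. `mlir.register_lowering(my_kernel_p, _my_kernel_lowering, platform="gpu")`, where `my_kernel_p` is the hypothetical primitive.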

View File

@@ -1453,7 +1453,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
 out_nodes = tuple(map(wrap_singleton_ir_values, ans))
 except TypeError as e:
 raise ValueError("Output of translation rule must be iterable: "
-f"{eqn}, got output {ans}") from e
+f"{eqn}, got output {ans} from {rule.__name__}") from e
 assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
 assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (