diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index 6e9b0a0ad..44cdaf1f1 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -236,7 +236,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps): ) out = custom_call( b"rms_forward_affine_mixed_dtype", - out_types=[ + result_types=[ ir.RankedTensorType.get(x_shape, w_type.element_type), ir.RankedTensorType.get((n1,), iv_element_type), ], @@ -244,7 +244,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps): backend_config=opaque, operand_layouts=default_layouts(x_shape, w_shape), result_layouts=default_layouts(x_shape, (n1,)), - ) + ).results return out @@ -277,7 +277,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps): ) out = custom_call( b"rms_backward_affine", - out_types=[ + result_types=[ ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(w_shape, w_type.element_type), ir.RankedTensorType.get(part_grad_shape, iv_type.element_type), @@ -286,7 +286,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps): backend_config=opaque, operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape), result_layouts=default_layouts(x_shape, w_shape, part_grad_shape), - ) + ).results return out @@ -345,7 +345,7 @@ See [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How from functools import reduce from operator import mul -from jax.abstract_arrays import ShapedArray +from jax.core import ShapedArray def _rms_norm_fwd_abstract(x, weight, eps): @@ -804,7 +804,7 @@ import jax import jax.numpy as jnp from build import gpu_ops from jax import core, dtypes -from jax.abstract_arrays import ShapedArray +from jax.core import ShapedArray from jax.experimental.maps import xmap from jax.experimental.pjit import pjit from jax.interpreters import mlir, xla @@ -887,7 +887,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps): ) out = custom_call( b"rms_forward_affine_mixed_dtype", - out_types=[ + result_types=[ ir.RankedTensorType.get(x_shape, w_type.element_type), ir.RankedTensorType.get((n1,), iv_element_type), ], @@ -895,7 +895,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps): backend_config=opaque, operand_layouts=default_layouts(x_shape, w_shape), result_layouts=default_layouts(x_shape, (n1,)), - ) + ).results return out @@ -928,7 +928,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps): ) out = custom_call( b"rms_backward_affine", - out_types=[ + result_types=[ ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(w_shape, w_type.element_type), ir.RankedTensorType.get(part_grad_shape, iv_type.element_type), @@ -937,7 +937,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps): backend_config=opaque, operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape), result_layouts=default_layouts(x_shape, w_shape, part_grad_shape), - ) + ).results return out diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 5cfaff667..dac8e34f6 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1453,7 +1453,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, out_nodes = tuple(map(wrap_singleton_ir_values, ans)) except TypeError as e: raise ValueError("Output of translation rule must be iterable: " - f"{eqn}, got output {ans}") from e + f"{eqn}, got output {ans} from {rule.__name__}") from e assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn) assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (