From d7fb895ee0668d4cdc6f56f1eb40d91792d490ec Mon Sep 17 00:00:00 2001
From: Frederic Bastien
Date: Sun, 5 Nov 2023 08:17:23 -0800
Subject: [PATCH 1/2] Fix the tutorial

---
 docs/Custom_Operation_for_GPUs.md | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md
index 6e9b0a0ad..44cdaf1f1 100644
--- a/docs/Custom_Operation_for_GPUs.md
+++ b/docs/Custom_Operation_for_GPUs.md
@@ -236,7 +236,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
     )
     out = custom_call(
         b"rms_forward_affine_mixed_dtype",
-        out_types=[
+        result_types=[
             ir.RankedTensorType.get(x_shape, w_type.element_type),
             ir.RankedTensorType.get((n1,), iv_element_type),
         ],
@@ -244,7 +244,7 @@
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, w_shape),
         result_layouts=default_layouts(x_shape, (n1,)),
-    )
+    ).results
     return out


@@ -277,7 +277,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
     )
     out = custom_call(
         b"rms_backward_affine",
-        out_types=[
+        result_types=[
             ir.RankedTensorType.get(x_shape, x_type.element_type),
             ir.RankedTensorType.get(w_shape, w_type.element_type),
             ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
@@ -286,7 +286,7 @@
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
         result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
-    )
+    ).results
     return out


@@ -345,7 +345,7 @@ See [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How
 from functools import reduce
 from operator import mul

-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray


 def _rms_norm_fwd_abstract(x, weight, eps):
@@ -804,7 +804,7 @@ import jax
 import jax.numpy as jnp
 from build import gpu_ops
 from jax import core, dtypes
-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray
 from jax.experimental.maps import xmap
 from jax.experimental.pjit import pjit
 from jax.interpreters import mlir, xla
@@ -887,7 +887,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
     )
     out = custom_call(
         b"rms_forward_affine_mixed_dtype",
-        out_types=[
+        result_types=[
             ir.RankedTensorType.get(x_shape, w_type.element_type),
             ir.RankedTensorType.get((n1,), iv_element_type),
         ],
@@ -895,7 +895,7 @@
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, w_shape),
         result_layouts=default_layouts(x_shape, (n1,)),
-    )
+    ).results
     return out


@@ -928,7 +928,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
     )
     out = custom_call(
         b"rms_backward_affine",
-        out_types=[
+        result_types=[
             ir.RankedTensorType.get(x_shape, x_type.element_type),
             ir.RankedTensorType.get(w_shape, w_type.element_type),
             ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
@@ -937,7 +937,7 @@
         backend_config=opaque,
         operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
         result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
-    )
+    ).results
     return out

From 1f7df8008ea0a1a2eb27d00be74579314bfd0917 Mon Sep 17 00:00:00 2001
From: Frederic Bastien
Date: Sun, 5 Nov 2023 08:17:33 -0800
Subject: [PATCH 2/2] Better error message.

---
 jax/_src/interpreters/mlir.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 5cfaff667..dac8e34f6 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1453,7 +1453,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
           out_nodes = tuple(map(wrap_singleton_ir_values, ans))
         except TypeError as e:
           raise ValueError("Output of translation rule must be iterable: "
-                           f"{eqn}, got output {ans}") from e
+                           f"{eqn}, got output {ans} from {rule.__name__}") from e
         assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
         assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
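
For readers updating their own lowering rules by hand, here is a minimal sketch of the pattern this patch migrates to. It reuses the tutorial's `default_layouts` helper; the `b"my_op_forward"` call target and the single-operand shape handling are hypothetical stand-ins, and it assumes a jaxlib whose `hlo_helpers.custom_call` takes `result_types=` (not the removed `out_types=`) and returns an op whose `.results` holds the output values:

```python
# Sketch only: "my_op_forward" is a hypothetical custom-call target; the
# tutorial's real targets are rms_forward_affine_mixed_dtype and
# rms_backward_affine, registered from its C++ extension.
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call


def default_layouts(*shapes):
    # Row-major layout for every operand/result, as in the tutorial.
    return [range(len(shape) - 1, -1, -1) for shape in shapes]


def _my_op_cuda_lowering(ctx, x):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    # result_types= replaces the old out_types= keyword, and the returned
    # op's .results (a sequence of ir.Value) is what the lowering rule
    # must hand back to JAX.
    return custom_call(
        b"my_op_forward",
        result_types=[ir.RankedTensorType.get(x_shape, x_type.element_type)],
        operands=[x],
        operand_layouts=default_layouts(x_shape),
        result_layouts=default_layouts(x_shape),
    ).results
```

Returning the bare op instead of `.results` is exactly the mistake the second patch's improved error message (`... from {rule.__name__}`) helps pinpoint, since it names the offending lowering rule.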