Fix a use-after-free bug in third_party/py/jax/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc

The backing array of the initializer_list is destroyed at the end of the full expression.

PiperOrigin-RevId: 689783482
This commit is contained in:
jax authors 2024-10-25 07:39:35 -07:00
parent 5fd4ea9054
commit 63c1699ed0

View File

@ -118,8 +118,8 @@ class MosaicGpuTest : public ::testing::Test {
};
TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) {
llvm::ArrayRef<int64_t> shape{1, 2, 3};
llvm::ArrayRef<int64_t> slice_shape{1, 2};
std::vector<int64_t> shape{1, 2, 3};
std::vector<int64_t> slice_shape{1, 2};
mlir::LLVM::LLVMPointerType pointer_type =
mlir::LLVM::LLVMPointerType::get(&context_);
@ -128,7 +128,7 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) {
EXPECT_THAT(
FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type,
memref_type, slice_shape),
memref_type, mlir::ArrayRef<int64_t>(slice_shape)),
StatusIs(
absl::StatusCode::kFailedPrecondition,
HasSubstr(
@ -136,8 +136,8 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) {
}
TEST_F(MosaicGpuTest, InitTmaDescriptorGracefullyRejectsSubByteTypes) {
llvm::ArrayRef<int64_t> shape{1, 2, 3};
llvm::ArrayRef<int64_t> slice_shape{1, 2, 3};
std::vector<int64_t> shape{1, 2, 3};
std::vector<int64_t> slice_shape{1, 2, 3};
mlir::LLVM::LLVMPointerType pointer_type =
mlir::LLVM::LLVMPointerType::get(&context_);
@ -145,14 +145,14 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorGracefullyRejectsSubByteTypes) {
mlir::MemRefType::get(shape, builder_.getI4Type());
EXPECT_THAT(FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type,
memref_type, slice_shape),
memref_type, mlir::ArrayRef<int64_t>(slice_shape)),
StatusIs(absl::StatusCode::kUnimplemented,
HasSubstr("Sub-byte types are not yet supported")));
}
TEST_F(MosaicGpuTest, InitTmaDescriptorProducesACallToRuntime) {
llvm::ArrayRef<int64_t> shape{1, 2, 3};
llvm::ArrayRef<int64_t> slice_shape{1, 2, 3};
std::vector<int64_t> shape{1, 2, 3};
std::vector<int64_t> slice_shape{1, 2, 3};
mlir::LLVM::LLVMPointerType pointer_type =
mlir::LLVM::LLVMPointerType::get(&context_);
@ -161,7 +161,7 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorProducesACallToRuntime) {
absl::StatusOr<mlir::func::FuncOp> fn_or =
FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type,
memref_type, slice_shape);
memref_type, mlir::ArrayRef<int64_t>(slice_shape));
ASSERT_OK(fn_or);
llvm::SmallVector<mlir::func::CallOp> call_ops =