Skip to content

Commit

Permalink
[compiler] Set attributes on packed args and on loads from them
Browse files Browse the repository at this point in the history
This commit lets LLVM know that the pointer to the packed argument
structure must not be null, must not be undef/poison, and is
dereferenceable.

It also transfers `noundef` and `nonnull` attributes from the old
parameters to the new loads from the argument struct. Those loads can
take `!noundef` and `!nonnull` metadata.

This should improve performance in certain cases, as this pass typically
runs before the final O3 optimization pipeline and any extra information
we can give LLVM should help.
  • Loading branch information
frasercrmck committed Jan 25, 2024
1 parent b12be1f commit 8754ae0
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 9 deletions.
12 changes: 9 additions & 3 deletions modules/compiler/test/lit/passes/add-kernel-wrapper-dbg.ll
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@
target triple = "spir64-unknown-unknown"
target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128"

; CHECK: define internal spir_kernel void @foo(ptr addrspace(1) %x, ptr addrspace(1) %y)
; CHECK: define internal spir_kernel void @foo(ptr addrspace(1) noundef nonnull %x, ptr addrspace(1) %y)
; CHECK-SAME: [[ATTRS:#[0-9]+]] !dbg [[SP:\![0-9]+]] {
define spir_kernel void @foo(ptr addrspace(1) %x, ptr addrspace(1) %y) #0 !dbg !10 {
define spir_kernel void @foo(ptr addrspace(1) noundef nonnull %x, ptr addrspace(1) %y) #0 !dbg !10 {
ret void
}

; CHECK: define spir_kernel void @foo.mux-kernel-wrapper(ptr %packed-args)
; CHECK: define spir_kernel void @foo.mux-kernel-wrapper(
; CHECK-SAME: ptr noundef nonnull dereferenceable(16) %packed-args)
; CHECK-SAME: [[NEW_ATTRS:#[0-9]+]] !dbg [[NEW_SP:\![0-9]+]] !mux_scheduled_fn {{\![0-9]+}} {
; Check that the 'noundef' and 'nonnull' attributes are transferred to the load
; of %x, but not %y
; CHECK: %x = load ptr addrspace(1), ptr {{.*}}, align 1,
; CHECK-SAME: !nonnull [[EMPTY:\![0-9]+]], !noundef [[EMPTY]]
; CHECK: %y = load ptr addrspace(1), ptr {{.*}}, align 1{{$}}
; Check that when we call the original kernel we've attached a debug location.
; This is required by LLVM.
; CHECK: call spir_kernel void @foo({{.*}}) [[ATTRS]], !dbg [[LOC:\![0-9]+]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128"

; CHECK: define internal spir_kernel void @[[K_NOINLINE:.*]](ptr addrspace(1) %in, ptr addrspace(1) %out) #[[ATTR_NOINLINE:.+]] {

; CHECK: @add.mux-kernel-wrapper(ptr {{%.*}}) #[[WRAPPER_ATTRS_INLINE:[0-9]+]] !mux_scheduled_fn [[SCHED_MD:\![0-9]+]] {
; CHECK: @add.mux-kernel-wrapper(ptr noundef nonnull dereferenceable(16) %packed-args) #[[WRAPPER_ATTRS_INLINE:[0-9]+]] !mux_scheduled_fn [[SCHED_MD:\![0-9]+]] {
; CHECK: %1 = getelementptr %MuxPackedArgs.add, ptr %packed-args, i32 0, i32 0
; CHECK: %in = load ptr addrspace(1), ptr %1, align 8
; CHECK: %2 = getelementptr %MuxPackedArgs.add, ptr %packed-args, i32 0, i32 1
Expand All @@ -38,7 +38,7 @@ define spir_kernel void @add(i32 addrspace(1)* %in, i32 addrspace(1)* %out) #0 {
ret void
}

; CHECK: @add_noinline.mux-kernel-wrapper(ptr {{%.*}}) #[[WRAPPER_ATTRS_NOINLINE:[0-9]+]] !mux_scheduled_fn [[SCHED_MD]] {
; CHECK: @add_noinline.mux-kernel-wrapper(ptr noundef nonnull dereferenceable(16) {{%.*}}) #[[WRAPPER_ATTRS_NOINLINE:[0-9]+]] !mux_scheduled_fn [[SCHED_MD]] {
; CHECK: call spir_kernel void @[[K_NOINLINE]](
define spir_kernel void @add_noinline(i32 addrspace(1)* %in, i32 addrspace(1)* %out) #1 {
ret void
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128"
; CHECK: define internal spir_kernel void @add(ptr addrspace(1) readonly %in, ptr addrspace(1) writeonly %out, i8 signext %x, ptr byval(i32) %s) [[ATTRS:#[0-9]+]] !test [[TEST:\![0-9]+]] {

; Check we've copied across all the metadata, and stolen the entry-point metadata
; CHECK: define spir_kernel void @orig.mux-kernel-wrapper(ptr %packed-args) [[WRAPPER_ATTRS:#[0-9]+]] !test [[TEST]] !mux_scheduled_fn [[SCHED_MD:\![0-9]+]] {
; CHECK: define spir_kernel void @orig.mux-kernel-wrapper(ptr noundef nonnull dereferenceable(21) %packed-args) [[WRAPPER_ATTRS:#[0-9]+]] !test [[TEST]] !mux_scheduled_fn [[SCHED_MD:\![0-9]+]] {
; Check we're calling the original kernel with the right attributes
; CHECK: call spir_kernel void @add(ptr addrspace(1) readonly %in, ptr addrspace(1) writeonly %out, i8 signext %x, ptr byval(i32) %s) [[ATTRS]]
define spir_kernel void @add(ptr addrspace(1) readonly %in, ptr addrspace(1) writeonly %out, i8 signext %x, ptr byval(i32) %s) #0 !test !0 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128"
; Check we've preserved scheduling parameters, their names, and their attributes.
; Check we've dropped !mux_scheduled_fn metadata, which can't be ensured
; correct after this transformation.
; CHECK: define void @add.mux-kernel-wrapper(ptr %packed-args, ptr noalias %wg-info) [[WRAPPER_ATTRS:#[0-9]+]] !mux_scheduled_fn [[WRAPPER_SCHED_PARAMS:\![0-9]+]] {
; CHECK: define void @add.mux-kernel-wrapper(ptr noundef nonnull dereferenceable(12) %packed-args, ptr noalias %wg-info) [[WRAPPER_ATTRS:#[0-9]+]] !mux_scheduled_fn [[WRAPPER_SCHED_PARAMS:\![0-9]+]] {
; Check we're calling the original kernel, passing through the scheduling
; parameters and with the right attributes
; CHECK: call void @add(ptr readonly %in, ptr byval(i32) %s, ptr noalias %wi-info, ptr noalias %wg-info) [[ATTRS]]
Expand Down
46 changes: 46 additions & 0 deletions modules/compiler/test/lit/passes/add-kernel-wrapper.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
; Copyright (C) Codeplay Software Limited
;
; Licensed under the Apache License, Version 2.0 (the "License") with LLVM
; Exceptions; you may not use this file except in compliance with the License.
; You may obtain a copy of the License at
;
; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt
;
; Unless required by applicable law or agreed to in writing, software
; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
; License for the specific language governing permissions and limitations
; under the License.
;
; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

; RUN: muxc --passes 'add-kernel-wrapper<unpacked>,verify' < %s | FileCheck %s

target triple = "spir64-unknown-unknown"
target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128"

; CHECK: define internal spir_kernel void @foo(ptr addrspace(1) noundef nonnull %x, ptr addrspace(1) %y)
; Kernel under test: %x carries 'noundef nonnull' while %y carries no
; attributes, so the wrapper's loads of the two parameters can be told apart.
define spir_kernel void @foo(ptr addrspace(1) noundef nonnull %x, ptr addrspace(1) %y) #0 {
ret void
}

; CHECK: define internal spir_kernel void @empty_args()
; Kernel with no parameters: exercises the wrapper generated when there are no
; kernel arguments to pack.
define spir_kernel void @empty_args() #0 {
ret void
}

; CHECK: define spir_kernel void @foo.mux-kernel-wrapper(
; CHECK-SAME: ptr noundef nonnull dereferenceable(16) %packed-args)
; Check that the 'noundef' and 'nonnull' attributes are transferred to the load
; of %x, but not %y
; CHECK: %x = load ptr addrspace(1), ptr {{.*}}, align 8,
; CHECK-SAME: !nonnull [[EMPTY:\![0-9]+]], !noundef [[EMPTY]]
; CHECK: %y = load ptr addrspace(1), ptr {{.*}}, align 8{{$}}
; CHECK: call spir_kernel void @foo({{.*}})

; Check we don't add 'nonnull', 'noundef', or 'dereferenceable' attributes to
; this parameter as it may be null, or empty.
; CHECK: define spir_kernel void @empty_args.mux-kernel-wrapper(ptr %packed-args)
; CHECK: call spir_kernel void @empty_args()

attributes #0 = { "mux-kernel"="entry-point" }
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ target datalayout = "e-p:64:64:64-m:e-i64:64-f80:128-n8:16:32:64-S128"

; CHECK: define internal void @add.mux-sched-wrapper(ptr readonly %in, ptr byval(i32) %s, ptr [[WIATTRS:noalias nonnull align 8 dereferenceable\(40\)]] %wi-info, ptr [[WGATTRS:noalias nonnull align 8 dereferenceable\(104\)]] %wg-info) [[SCHED_ATTRS:#[0-9]+]] !mux_scheduled_fn [[SCHED_MD:\![0-9]+]] {

; CHECK: define void @add.mux-kernel-wrapper(ptr %packed-args, ptr [[WGATTRS]] %wg-info) [[WRAPPER_ATTRS:#[0-9]+]] !mux_scheduled_fn [[WRAPPER_MD:\![0-9]+]] {
; CHECK: define void @add.mux-kernel-wrapper(ptr noundef nonnull dereferenceable(12) %packed-args, ptr [[WGATTRS]] %wg-info) [[WRAPPER_ATTRS:#[0-9]+]] !mux_scheduled_fn [[WRAPPER_MD:\![0-9]+]] {
; Check we're initializing the work-item info on the stack
; CHECK: %wi-info = alloca %MuxWorkItemInfo, align 8
; Check we're calling the original kernel, passing through the scheduling
Expand Down
30 changes: 29 additions & 1 deletion modules/compiler/utils/source/add_kernel_wrapper_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ PreservedAnalyses compiler::utils::AddKernelWrapperPass::run(
Module &M, ModuleAnalysisManager &AM) {
bool Changed = false;
SmallPtrSet<Function *, 4> NewKernels;
auto &DL = M.getDataLayout();
auto &BI = AM.getResult<BuiltinInfoAnalysis>(M);

const auto &schedParamInfo = BI.getMuxSchedulingParameters(M);
Expand Down Expand Up @@ -166,6 +167,20 @@ PreservedAnalyses compiler::utils::AddKernelWrapperPass::run(

auto *packedArgPtr = newFunction->getArg(0);
packedArgPtr->setName("packed-args");
// Add some helpful attributes to this argument.
// FIXME: Could we also mandate alignment? Can we guarantee noalias on the
// packed argument structure but not on the pointers it contains?
// If there are no kernel arguments to pack, we don't require the runtime
// to pass a valid pointer: it could be null.
if (!structType->isEmptyTy()) {
// It is invalid for a Mux runtime to pass a null or undef packed argument
// struct.
packedArgPtr->addAttr(Attribute::NoUndef);
packedArgPtr->addAttr(Attribute::NonNull);
// The packed argument struct must be fully dereferenceable.
packedArgPtr->addAttr(Attribute::getWithDereferenceableBytes(
newFunction->getContext(), DL.getTypeAllocSize(structType)));
}

assert(packedArgPtr->getType()->isPointerTy() &&
"First argument should be pointer to the packed args structure");
Expand Down Expand Up @@ -230,7 +245,20 @@ PreservedAnalyses compiler::utils::AddKernelWrapperPass::run(
} else if (arg.hasByValAttr()) {
params.push_back(gep);
} else {
params.push_back(ir.CreateAlignedLoad(type, gep, llvmAlignment));
auto *arg_load = ir.CreateAlignedLoad(type, gep, llvmAlignment);
// If the old argument was marked 'noundef', the result of the load
// from it will also be noundef. Use metadata to convey that.
if (F.getArg(argMapping.OldArgIdx)->hasAttribute(Attribute::NoUndef)) {
MDNode *md = MDNode::get(newFunction->getContext(), std::nullopt);
arg_load->setMetadata(LLVMContext::MD_noundef, md);
}
// If the old argument was marked 'nonnull', the result of the load
// from it will also be nonnull. Use metadata to convey that.
if (F.getArg(argMapping.OldArgIdx)->hasAttribute(Attribute::NonNull)) {
MDNode *md = MDNode::get(newFunction->getContext(), std::nullopt);
arg_load->setMetadata(LLVMContext::MD_nonnull, md);
}
params.push_back(arg_load);
}
// Set the name to help readability
params.back()->setName(arg.getName());
Expand Down

0 comments on commit 8754ae0

Please sign in to comment.