From 5c447dd84f8ae0e1d48ff9a2eae26ce8c4958101 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Sat, 20 Apr 2024 04:07:57 +1200 Subject: [PATCH] Update packed_stride.hpp to add CUTLASS_HOST_DEVICE decorator to new functions (#1495) --- .../util/include/cutlass/util/packed_stride.hpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp index a3ed56a703..e0f2ec0b56 100644 --- a/tools/util/include/cutlass/util/packed_stride.hpp +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -111,6 +111,7 @@ make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape // Strides with group mode template +CUTLASS_HOST_DEVICE cute::Stride, cute::Int<0>> make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { static_assert(std::is_integral_v, @@ -121,6 +122,7 @@ make_cute_packed_stride(cute::Stride, cute::Int<0>> s, } template +CUTLASS_HOST_DEVICE cute::Stride, StrideIntT, cute::Int<0>> make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { static_assert(std::is_integral_v, @@ -140,6 +142,7 @@ make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, // right in KTRSC order and can be coalesced to just k. // We enforce this condition here with asserts. template +CUTLASS_HOST_DEVICE cute::Stride, cute::Int<0>> make_cute_packed_stride( cute::Stride, cute::Int<0>> s, @@ -169,6 +172,7 @@ make_cute_packed_stride( // Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1) template +CUTLASS_HOST_DEVICE cute::Stride, cute::Int<1>> make_cute_packed_stride( cute::Stride, cute::Int<1>> s, @@ -185,6 +189,7 @@ make_cute_packed_stride( // Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1) template +CUTLASS_HOST_DEVICE cute::Stride, cute::Int<1>> make_cute_packed_stride( cute::Stride, cute::Int<1>> s, @@ -202,6 +207,7 @@ make_cute_packed_stride( // Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1) template +CUTLASS_HOST_DEVICE cute::Stride, cute::Int<1>> make_cute_packed_stride( cute::Stride, cute::Int<1>> s, @@ -224,6 +230,7 @@ make_cute_packed_stride( // Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s)) template +CUTLASS_HOST_DEVICE cute::Stride, IntT>> make_cute_packed_stride( cute::Stride, IntT>> s, @@ -241,6 +248,7 @@ make_cute_packed_stride( // Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r)) template +CUTLASS_HOST_DEVICE cute::Stride, IntT, IntT>> make_cute_packed_stride( cute::Stride, IntT, IntT>> s, @@ -260,6 +268,7 @@ make_cute_packed_stride( // Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t)) template +CUTLASS_HOST_DEVICE cute::Stride, IntT, IntT, IntT>> make_cute_packed_stride( cute::Stride, IntT, IntT, IntT>> s, @@ -286,6 +295,7 @@ make_cute_packed_stride( // Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad // Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad template +CUTLASS_HOST_DEVICE cute::Stride, cute::Stride> make_cute_packed_stride( cute::Stride, cute::Stride> s, @@ -311,6 +321,7 @@ make_cute_packed_stride( // Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad // Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad template +CUTLASS_HOST_DEVICE cute::Stride, cute::Stride> make_cute_packed_stride( cute::Stride, cute::Stride> s, @@ -339,6 +350,7 @@ make_cute_packed_stride( // Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad // Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad template +CUTLASS_HOST_DEVICE cute::Stride, cute::Stride> make_cute_packed_stride( cute::Stride, cute::Stride> s, @@ -370,6 +382,7 @@ make_cute_packed_stride( // cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq) template +CUTLASS_HOST_DEVICE cute::Stride, IntT> make_cute_packed_stride( cute::Stride, IntT> s, @@ -386,6 +399,7 @@ make_cute_packed_stride( // cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq) template +CUTLASS_HOST_DEVICE cute::Stride, IntT> make_cute_packed_stride( cute::Stride, IntT> s, @@ -402,6 +416,7 @@ make_cute_packed_stride( // cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq) template +CUTLASS_HOST_DEVICE cute::Stride, IntT> make_cute_packed_stride( cute::Stride, IntT> s, @@ -424,6 +439,7 @@ make_cute_packed_stride( // Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0) template +CUTLASS_HOST_DEVICE cute::Stride, IntT>, cute::Int<0>> make_cute_packed_stride( cute::Stride, IntT>, cute::Int<0>> s, @@ -462,6 +478,7 @@ make_cute_packed_stride( // Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0) template +CUTLASS_HOST_DEVICE cute::Stride, IntT, IntT, IntT>, cute::Int<0>> make_cute_packed_stride( cute::Stride, IntT, IntT, IntT>, cute::Int<0>> s,